Skip to content

Commit 081a66f

Browse files
authored
[DXIL] implement dot intrinsic lowering for integers (#85662)
This implements part 1 of 2 for #83626. - `CGBuiltin.cpp` - modified to have separate cases for signed and unsigned integers. - `SemaChecking.cpp` - modified to prevent the generation of a double dot product intrinsic if the builtin were to be called directly. - `IntrinsicsDirectX.td` - creation of the signed and unsigned dot intrinsics needed for instruction expansion. - `DXILIntrinsicExpansion.cpp` - handle instruction expansion cases for integer dot product.
1 parent 0081ec1 commit 081a66f

File tree

7 files changed

+200
-20
lines changed

7 files changed

+200
-20
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18036,6 +18036,17 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1803618036
return Arg;
1803718037
}
1803818038

18039+
/// Pick the DirectX dot-product intrinsic that matches the representation of
/// \p QT.
///
/// \param QT type of a dot-product operand.
/// \returns Intrinsic::dx_sdot when \p QT has a signed integer representation,
///          Intrinsic::dx_udot when it has an unsigned integer representation,
///          and Intrinsic::dx_dot otherwise. The fall-through case is asserted
///          to be a floating-point representation.
Intrinsic::ID getDotProductIntrinsic(QualType QT) {
  if (QT->hasSignedIntegerRepresentation())
    return Intrinsic::dx_sdot;
  if (QT->hasUnsignedIntegerRepresentation())
    return Intrinsic::dx_udot;

  // Anything that is not an integer representation must be floating point.
  assert(QT->hasFloatingRepresentation() &&
         "dot product operand must have an integer or floating representation");
  return Intrinsic::dx_dot;
}
18049+
1803918050
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1804018051
const CallExpr *E) {
1804118052
if (!getLangOpts().HLSL)
@@ -18096,7 +18107,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1809618107
"Dot product requires vectors to be of the same size.");
1809718108

1809818109
return Builder.CreateIntrinsic(
18099-
/*ReturnType=*/T0->getScalarType(), Intrinsic::dx_dot,
18110+
/*ReturnType=*/T0->getScalarType(),
18111+
getDotProductIntrinsic(E->getArg(0)->getType()),
1810018112
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
1810118113
} break;
1810218114
case Builtin::BI__builtin_hlsl_lerp: {

clang/lib/Sema/SemaChecking.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5484,6 +5484,18 @@ bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
54845484
checkFloatorHalf);
54855485
}
54865486

5487+
/// Run CheckArgsTypesAreCorrect over the arguments of \p TheCall with a
/// predicate that flags vector arguments whose element type is neither half
/// nor 32-bit float, passing FloatTy as the expected type.
///
/// \returns the result of CheckArgsTypesAreCorrect unchanged.
///
/// NOTE(review): the predicate fires for any vector whose element type is not
/// half/float32, not only for double elements; whether integer vectors reach
/// it depends on how dyn_cast<VectorType> treats the argument's QualType and
/// on CheckArgsTypesAreCorrect -- confirm against those definitions.
bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
  auto IsDisallowedVector = [](clang::QualType ArgTy) -> bool {
    const auto *Vec = dyn_cast<VectorType>(ArgTy);
    // Non-vector arguments are never flagged by this predicate.
    if (!Vec)
      return false;
    clang::QualType ElemTy = Vec->getElementType();
    return !(ElemTy->isHalfType() || ElemTy->isFloat32Type());
  };
  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
                                  IsDisallowedVector);
}
5498+
54875499
void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
54885500
QualType ReturnType) {
54895501
auto *VecTyA = TheCall->getArg(0)->getType()->getAs<VectorType>();
@@ -5520,6 +5532,8 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
55205532
return true;
55215533
if (SemaBuiltinVectorToScalarMath(TheCall))
55225534
return true;
5535+
if (CheckNoDoubleVectors(this, TheCall))
5536+
return true;
55235537
break;
55245538
}
55255539
case Builtin::BI__builtin_hlsl_elementwise_rcp: {

clang/test/CodeGenHLSL/builtins/dot.hlsl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,31 @@
1111
// NATIVE_HALF: ret i16 %dx.dot
1212
int16_t test_dot_short(int16_t p0, int16_t p1) { return dot(p0, p1); }
1313

14-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
14+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %0, <2 x i16> %1)
1515
// NATIVE_HALF: ret i16 %dx.dot
1616
int16_t test_dot_short2(int16_t2 p0, int16_t2 p1) { return dot(p0, p1); }
1717

18-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
18+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v3i16(<3 x i16> %0, <3 x i16> %1)
1919
// NATIVE_HALF: ret i16 %dx.dot
2020
int16_t test_dot_short3(int16_t3 p0, int16_t3 p1) { return dot(p0, p1); }
2121

22-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
22+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.sdot.v4i16(<4 x i16> %0, <4 x i16> %1)
2323
// NATIVE_HALF: ret i16 %dx.dot
2424
int16_t test_dot_short4(int16_t4 p0, int16_t4 p1) { return dot(p0, p1); }
2525

2626
// NATIVE_HALF: %dx.dot = mul i16 %0, %1
2727
// NATIVE_HALF: ret i16 %dx.dot
2828
uint16_t test_dot_ushort(uint16_t p0, uint16_t p1) { return dot(p0, p1); }
2929

30-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
30+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v2i16(<2 x i16> %0, <2 x i16> %1)
3131
// NATIVE_HALF: ret i16 %dx.dot
3232
uint16_t test_dot_ushort2(uint16_t2 p0, uint16_t2 p1) { return dot(p0, p1); }
3333

34-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
34+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %0, <3 x i16> %1)
3535
// NATIVE_HALF: ret i16 %dx.dot
3636
uint16_t test_dot_ushort3(uint16_t3 p0, uint16_t3 p1) { return dot(p0, p1); }
3737

38-
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
38+
// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.udot.v4i16(<4 x i16> %0, <4 x i16> %1)
3939
// NATIVE_HALF: ret i16 %dx.dot
4040
uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
4141
#endif
@@ -44,63 +44,63 @@ uint16_t test_dot_ushort4(uint16_t4 p0, uint16_t4 p1) { return dot(p0, p1); }
4444
// CHECK: ret i32 %dx.dot
4545
int test_dot_int(int p0, int p1) { return dot(p0, p1); }
4646

47-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
47+
// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v2i32(<2 x i32> %0, <2 x i32> %1)
4848
// CHECK: ret i32 %dx.dot
4949
int test_dot_int2(int2 p0, int2 p1) { return dot(p0, p1); }
5050

51-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
51+
// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v3i32(<3 x i32> %0, <3 x i32> %1)
5252
// CHECK: ret i32 %dx.dot
5353
int test_dot_int3(int3 p0, int3 p1) { return dot(p0, p1); }
5454

55-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
55+
// CHECK: %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %0, <4 x i32> %1)
5656
// CHECK: ret i32 %dx.dot
5757
int test_dot_int4(int4 p0, int4 p1) { return dot(p0, p1); }
5858

5959
// CHECK: %dx.dot = mul i32 %0, %1
6060
// CHECK: ret i32 %dx.dot
6161
uint test_dot_uint(uint p0, uint p1) { return dot(p0, p1); }
6262

63-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
63+
// CHECK: %dx.dot = call i32 @llvm.dx.udot.v2i32(<2 x i32> %0, <2 x i32> %1)
6464
// CHECK: ret i32 %dx.dot
6565
uint test_dot_uint2(uint2 p0, uint2 p1) { return dot(p0, p1); }
6666

67-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
67+
// CHECK: %dx.dot = call i32 @llvm.dx.udot.v3i32(<3 x i32> %0, <3 x i32> %1)
6868
// CHECK: ret i32 %dx.dot
6969
uint test_dot_uint3(uint3 p0, uint3 p1) { return dot(p0, p1); }
7070

71-
// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
71+
// CHECK: %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %0, <4 x i32> %1)
7272
// CHECK: ret i32 %dx.dot
7373
uint test_dot_uint4(uint4 p0, uint4 p1) { return dot(p0, p1); }
7474

7575
// CHECK: %dx.dot = mul i64 %0, %1
7676
// CHECK: ret i64 %dx.dot
7777
int64_t test_dot_long(int64_t p0, int64_t p1) { return dot(p0, p1); }
7878

79-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
79+
// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v2i64(<2 x i64> %0, <2 x i64> %1)
8080
// CHECK: ret i64 %dx.dot
8181
int64_t test_dot_long2(int64_t2 p0, int64_t2 p1) { return dot(p0, p1); }
8282

83-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
83+
// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v3i64(<3 x i64> %0, <3 x i64> %1)
8484
// CHECK: ret i64 %dx.dot
8585
int64_t test_dot_long3(int64_t3 p0, int64_t3 p1) { return dot(p0, p1); }
8686

87-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
87+
// CHECK: %dx.dot = call i64 @llvm.dx.sdot.v4i64(<4 x i64> %0, <4 x i64> %1)
8888
// CHECK: ret i64 %dx.dot
8989
int64_t test_dot_long4(int64_t4 p0, int64_t4 p1) { return dot(p0, p1); }
9090

9191
// CHECK: %dx.dot = mul i64 %0, %1
9292
// CHECK: ret i64 %dx.dot
9393
uint64_t test_dot_ulong(uint64_t p0, uint64_t p1) { return dot(p0, p1); }
9494

95-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
95+
// CHECK: %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %0, <2 x i64> %1)
9696
// CHECK: ret i64 %dx.dot
9797
uint64_t test_dot_ulong2(uint64_t2 p0, uint64_t2 p1) { return dot(p0, p1); }
9898

99-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
99+
// CHECK: %dx.dot = call i64 @llvm.dx.udot.v3i64(<3 x i64> %0, <3 x i64> %1)
100100
// CHECK: ret i64 %dx.dot
101101
uint64_t test_dot_ulong3(uint64_t3 p0, uint64_t3 p1) { return dot(p0, p1); }
102102

103-
// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
103+
// CHECK: %dx.dot = call i64 @llvm.dx.udot.v4i64(<4 x i64> %0, <4 x i64> %1)
104104
// CHECK: ret i64 %dx.dot
105105
uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
106106

clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,12 @@ int test_builtin_dot_bool_type_promotion(bool p0, bool p1) {
108108
return __builtin_hlsl_dot(p0, p1);
109109
// expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
110110
}
111+
112+
double test_dot_double(double2 p0, double2 p1) {
113+
return dot(p0, p1);
114+
// expected-error@-1 {{call to 'dot' is ambiguous}}
115+
}
116+
double test_dot_double_builtin(double2 p0, double2 p1) {
117+
return __builtin_hlsl_dot(p0, p1);
118+
// expected-error@-1 {{passing 'double2' (aka 'vector<double, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
119+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,18 @@ def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
2323
def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
2424
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
2525
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
26+
2627
def int_dx_dot :
2728
Intrinsic<[LLVMVectorElementType<0>],
28-
[llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
29+
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
30+
[IntrNoMem, IntrWillReturn, Commutative] >;
31+
def int_dx_sdot :
32+
Intrinsic<[LLVMVectorElementType<0>],
33+
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
34+
[IntrNoMem, IntrWillReturn, Commutative] >;
35+
def int_dx_udot :
36+
Intrinsic<[LLVMVectorElementType<0>],
37+
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
2938
[IntrNoMem, IntrWillReturn, Commutative] >;
3039

3140
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,44 @@ static bool isIntrinsicExpansion(Function &F) {
3939
case Intrinsic::dx_uclamp:
4040
case Intrinsic::dx_lerp:
4141
case Intrinsic::dx_rcp:
42+
case Intrinsic::dx_sdot:
43+
case Intrinsic::dx_udot:
4244
return true;
4345
}
4446
return false;
4547
}
4648

49+
/// Expand a call to an integer dot-product intrinsic into scalar operations.
///
/// The expansion multiplies element 0 of both operands and then, for every
/// remaining element index, folds the next pair of elements into the running
/// result with a multiply-add intrinsic (dx_imad for dx_sdot, dx_umad for
/// dx_udot). All uses of \p Orig are redirected to the final value and
/// \p Orig is erased.
///
/// \param Orig         the call to dx_sdot / dx_udot; both operands must be
///                     fixed-width vectors.
/// \param DotIntrinsic the intrinsic ID of the callee; must be
///                     Intrinsic::dx_sdot or Intrinsic::dx_udot.
/// \returns true (the call is always expanded).
static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
  assert((DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot) &&
         "expected an integer dot-product intrinsic");
  const Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
                                         ? Intrinsic::dx_imad
                                         : Intrinsic::dx_umad;
  Value *A = Orig->getOperand(0);
  Value *B = Orig->getOperand(1);
  // Query the types inside the assert so no local is left unused when
  // assertions are compiled out.
  assert(A->getType()->isVectorTy() && B->getType()->isVectorTy() &&
         "dot product operands must be vectors");

  // cast<> rather than dyn_cast<>: the result is dereferenced unconditionally,
  // so a non-fixed vector should trip the cast assertion instead of yielding a
  // null pointer.
  auto *AVec = cast<FixedVectorType>(A->getType());
  const unsigned NumElts = AVec->getNumElements();

  IRBuilder<> Builder(Orig->getParent());
  Builder.SetInsertPoint(Orig);

  // Seed the accumulator with the product of the first pair of elements.
  // The index is spelled as uint64_t so the literal 0 cannot be taken for a
  // null Value*.
  Value *Elt0 = Builder.CreateExtractElement(A, uint64_t{0});
  Value *Elt1 = Builder.CreateExtractElement(B, uint64_t{0});
  Value *Result = Builder.CreateMul(Elt0, Elt1);

  // Fold each remaining pair in with mad(a[i], b[i], acc).
  for (uint64_t I = 1; I < NumElts; ++I) {
    Elt0 = Builder.CreateExtractElement(A, I);
    Elt1 = Builder.CreateExtractElement(B, I);
    Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
                                     ArrayRef<Value *>{Elt0, Elt1, Result},
                                     nullptr, "dx.mad");
  }

  Orig->replaceAllUsesWith(Result);
  Orig->eraseFromParent();
  return true;
}
79+
4780
static bool expandExpIntrinsic(CallInst *Orig) {
4881
Value *X = Orig->getOperand(0);
4982
IRBuilder<> Builder(Orig->getParent());
@@ -191,6 +224,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
191224
return expandLerpIntrinsic(Orig);
192225
case Intrinsic::dx_rcp:
193226
return expandRcpIntrinsic(Orig);
227+
case Intrinsic::dx_sdot:
228+
case Intrinsic::dx_udot:
229+
return expandIntegerDot(Orig, F.getIntrinsicID());
194230
}
195231
return false;
196232
}

llvm/test/CodeGen/DirectX/idot.ll

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
2+
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
3+
4+
; Make sure dxil operation function calls for dot are generated for int/uint vectors.
5+
6+
; CHECK-LABEL: dot_int16_t2
7+
define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
entry:
; CHECK: extractelement <2 x i16> %a, i64 0
; CHECK: extractelement <2 x i16> %b, i64 0
; CHECK: mul i16 %{{.*}}, %{{.*}}
; CHECK: extractelement <2 x i16> %a, i64 1
; CHECK: extractelement <2 x i16> %b, i64 1
; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; The overload suffix names the operand type, so <2 x i16> operands use .v2i16.
  %dx.dot = call i16 @llvm.dx.sdot.v2i16(<2 x i16> %a, <2 x i16> %b)
  ret i16 %dx.dot
}
19+
20+
; CHECK-LABEL: sdot_int4
21+
define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
22+
entry:
23+
; CHECK: extractelement <4 x i32> %a, i64 0
24+
; CHECK: extractelement <4 x i32> %b, i64 0
25+
; CHECK: mul i32 %{{.*}}, %{{.*}}
26+
; CHECK: extractelement <4 x i32> %a, i64 1
27+
; CHECK: extractelement <4 x i32> %b, i64 1
28+
; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
29+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
30+
; CHECK: extractelement <4 x i32> %a, i64 2
31+
; CHECK: extractelement <4 x i32> %b, i64 2
32+
; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
33+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
34+
; CHECK: extractelement <4 x i32> %a, i64 3
35+
; CHECK: extractelement <4 x i32> %b, i64 3
36+
; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
37+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
38+
%dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
39+
ret i32 %dx.dot
40+
}
41+
42+
; CHECK-LABEL: dot_uint16_t3
43+
define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
44+
entry:
45+
; CHECK: extractelement <3 x i16> %a, i64 0
46+
; CHECK: extractelement <3 x i16> %b, i64 0
47+
; CHECK: mul i16 %{{.*}}, %{{.*}}
48+
; CHECK: extractelement <3 x i16> %a, i64 1
49+
; CHECK: extractelement <3 x i16> %b, i64 1
50+
; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
51+
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
52+
; CHECK: extractelement <3 x i16> %a, i64 2
53+
; CHECK: extractelement <3 x i16> %b, i64 2
54+
; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
55+
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
56+
%dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
57+
ret i16 %dx.dot
58+
}
59+
60+
; CHECK-LABEL: dot_uint4
61+
define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
62+
entry:
63+
; CHECK: extractelement <4 x i32> %a, i64 0
64+
; CHECK: extractelement <4 x i32> %b, i64 0
65+
; CHECK: mul i32 %{{.*}}, %{{.*}}
66+
; CHECK: extractelement <4 x i32> %a, i64 1
67+
; CHECK: extractelement <4 x i32> %b, i64 1
68+
; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
69+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
70+
; CHECK: extractelement <4 x i32> %a, i64 2
71+
; CHECK: extractelement <4 x i32> %b, i64 2
72+
; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
73+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
74+
; CHECK: extractelement <4 x i32> %a, i64 3
75+
; CHECK: extractelement <4 x i32> %b, i64 3
76+
; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
77+
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
78+
%dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
79+
ret i32 %dx.dot
80+
}
81+
82+
; CHECK-LABEL: dot_uint64_t4
83+
define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
84+
entry:
85+
; CHECK: extractelement <2 x i64> %a, i64 0
86+
; CHECK: extractelement <2 x i64> %b, i64 0
87+
; CHECK: mul i64 %{{.*}}, %{{.*}}
88+
; CHECK: extractelement <2 x i64> %a, i64 1
89+
; CHECK: extractelement <2 x i64> %b, i64 1
90+
; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
91+
; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
92+
%dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
93+
ret i64 %dx.dot
94+
}
95+
96+
; Intrinsic declarations; each overload suffix matches its operand vector type.
declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i16 @llvm.dx.udot.v3i16(<3 x i16>, <3 x i16>)
declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)

0 commit comments

Comments
 (0)