Skip to content

Commit 060df78

Browse files
authored
[DXIL] Add Float Dot Intrinsic Lowering (#86071)
Completes #83626 - `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count - `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and `dot4` inntrinsics -`DXIL.td` add dxilop intrinsic lowering for `dot2`, `dot3`, & `dot4`. - `DXILOpLowering.cpp` - add vector arg flattening for dot product. - `DXILOpBuilder.h` - modify `createDXILOpCall` to take a smallVector instead of an iterator - `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small vector up to the calling function in `DXILOpLowering.cpp`. - Moving one function up gives us access to the `CallInst` and `Function` which were needed to distinguish the dot product intrinsics and get the operands without using the iterator.
1 parent 765d4c4 commit 060df78

File tree

11 files changed

+230
-34
lines changed

11 files changed

+230
-34
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18066,15 +18066,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
1806618066
return Arg;
1806718067
}
1806818068

18069-
Intrinsic::ID getDotProductIntrinsic(QualType QT) {
18069+
Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
18070+
if (QT->hasFloatingRepresentation()) {
18071+
switch (elementCount) {
18072+
case 2:
18073+
return Intrinsic::dx_dot2;
18074+
case 3:
18075+
return Intrinsic::dx_dot3;
18076+
case 4:
18077+
return Intrinsic::dx_dot4;
18078+
}
18079+
}
1807018080
if (QT->hasSignedIntegerRepresentation())
1807118081
return Intrinsic::dx_sdot;
18072-
if (QT->hasUnsignedIntegerRepresentation())
18073-
return Intrinsic::dx_udot;
1807418082

18075-
assert(QT->hasFloatingRepresentation());
18076-
return Intrinsic::dx_dot;
18077-
;
18083+
assert(QT->hasUnsignedIntegerRepresentation());
18084+
return Intrinsic::dx_udot;
1807818085
}
1807918086

1808018087
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
@@ -18128,8 +18135,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1812818135
assert(T0->getScalarType() == T1->getScalarType() &&
1812918136
"Dot product of vectors need the same element types.");
1813018137

18131-
[[maybe_unused]] auto *VecTy0 =
18132-
E->getArg(0)->getType()->getAs<VectorType>();
18138+
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
1813318139
[[maybe_unused]] auto *VecTy1 =
1813418140
E->getArg(1)->getType()->getAs<VectorType>();
1813518141
// A HLSLVectorTruncation should have happend
@@ -18138,7 +18144,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1813818144

1813918145
return Builder.CreateIntrinsic(
1814018146
/*ReturnType=*/T0->getScalarType(),
18141-
getDotProductIntrinsic(E->getArg(0)->getType()),
18147+
getDotProductIntrinsic(E->getArg(0)->getType(),
18148+
VecTy0->getNumElements()),
1814218149
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
1814318150
} break;
1814418151
case Builtin::BI__builtin_hlsl_lerp: {

clang/test/CodeGenHLSL/builtins/dot.hlsl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,56 +110,56 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
110110
// NO_HALF: ret float %dx.dot
111111
half test_dot_half(half p0, half p1) { return dot(p0, p1); }
112112

113-
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
113+
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
114114
// NATIVE_HALF: ret half %dx.dot
115-
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
115+
// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
116116
// NO_HALF: ret float %dx.dot
117117
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }
118118

119-
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
119+
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
120120
// NATIVE_HALF: ret half %dx.dot
121-
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
121+
// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
122122
// NO_HALF: ret float %dx.dot
123123
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }
124124

125-
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
125+
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
126126
// NATIVE_HALF: ret half %dx.dot
127-
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
127+
// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
128128
// NO_HALF: ret float %dx.dot
129129
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }
130130

131131
// CHECK: %dx.dot = fmul float %0, %1
132132
// CHECK: ret float %dx.dot
133133
float test_dot_float(float p0, float p1) { return dot(p0, p1); }
134134

135-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
135+
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
136136
// CHECK: ret float %dx.dot
137137
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }
138138

139-
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
139+
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
140140
// CHECK: ret float %dx.dot
141141
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
142142

143-
// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
143+
// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
144144
// CHECK: ret float %dx.dot
145145
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
146146

147-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
147+
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
148148
// CHECK: ret float %dx.dot
149149
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
150150

151-
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
151+
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
152152
// CHECK: ret float %dx.dot
153153
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
154154

155-
// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
155+
// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
156156
// CHECK: ret float %dx.dot
157157
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
158158

159159
// CHECK: %conv = sitofp i32 %1 to float
160160
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
161161
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
162-
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
162+
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
163163
// CHECK: ret float %dx.dot
164164
float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
165165
return dot(p0, p1);
@@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
168168
// CHECK: %conv = sitofp i32 %1 to float
169169
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
170170
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
171-
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
171+
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
172172
// CHECK: ret float %dx.dot
173173
float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
174174
return dot(p0, p1);

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
2424
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
2525
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
2626

27-
def int_dx_dot :
27+
def int_dx_dot2 :
28+
Intrinsic<[LLVMVectorElementType<0>],
29+
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
30+
[IntrNoMem, IntrWillReturn, Commutative] >;
31+
def int_dx_dot3 :
32+
Intrinsic<[LLVMVectorElementType<0>],
33+
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
34+
[IntrNoMem, IntrWillReturn, Commutative] >;
35+
def int_dx_dot4 :
2836
Intrinsic<[LLVMVectorElementType<0>],
2937
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
3038
[IntrNoMem, IntrWillReturn, Commutative] >;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
303303
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
304304
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
305305
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
306+
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
307+
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
308+
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
309+
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
310+
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
311+
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
312+
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
313+
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
314+
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
306315
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
307316
"Reads the thread ID">;
308317
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ namespace dxil {
254254

255255
CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
256256
Type *OverloadTy,
257-
llvm::iterator_range<Use *> Args) {
257+
SmallVector<Value *> Args) {
258258
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
259259

260260
OverloadKind Kind = getOverloadKind(OverloadTy);
@@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
272272
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
273273
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
274274
}
275-
SmallVector<Value *> FullArgs;
276-
FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
277-
FullArgs.append(Args.begin(), Args.end());
278-
return B.CreateCall(DXILFn, FullArgs);
275+
276+
return B.CreateCall(DXILFn, Args);
279277
}
280278

281279
Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H
1414

1515
#include "DXILConstants.h"
16-
#include "llvm/ADT/iterator_range.h"
16+
#include "llvm/ADT/SmallVector.h"
1717

1818
namespace llvm {
1919
class Module;
@@ -35,8 +35,7 @@ class DXILOpBuilder {
3535
/// \param OverloadTy Overload type of the DXIL Op call constructed
3636
/// \return DXIL Op call constructed
3737
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
38-
Type *OverloadTy,
39-
llvm::iterator_range<Use *> Args);
38+
Type *OverloadTy, SmallVector<Value *> Args);
4039
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
4140
static const char *getOpCodeName(dxil::OpCode DXILOp);
4241

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,48 @@
3030
using namespace llvm;
3131
using namespace llvm::dxil;
3232

33+
static bool isVectorArgExpansion(Function &F) {
34+
switch (F.getIntrinsicID()) {
35+
case Intrinsic::dx_dot2:
36+
case Intrinsic::dx_dot3:
37+
case Intrinsic::dx_dot4:
38+
return true;
39+
}
40+
return false;
41+
}
42+
43+
static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
44+
SmallVector<Value *, 4> ExtractedElements;
45+
auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
46+
for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
47+
Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
48+
Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
49+
ExtractedElements.push_back(ExtractedElement);
50+
}
51+
return ExtractedElements;
52+
}
53+
54+
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
55+
IRBuilder<> &Builder) {
56+
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
57+
unsigned NumOperands = Orig->getNumOperands() - 1;
58+
assert(NumOperands > 0);
59+
Value *Arg0 = Orig->getOperand(0);
60+
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
61+
assert(VecArg0);
62+
SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
63+
for (unsigned I = 1; I < NumOperands; ++I) {
64+
Value *Arg = Orig->getOperand(I);
65+
[[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
66+
assert(VecArg);
67+
assert(VecArg0->getElementType() == VecArg->getElementType());
68+
assert(VecArg0->getNumElements() == VecArg->getNumElements());
69+
auto NextOperandList = populateOperands(Arg, Builder);
70+
NewOperands.append(NextOperandList.begin(), NextOperandList.end());
71+
}
72+
return NewOperands;
73+
}
74+
3375
static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
3476
IRBuilder<> B(M.getContext());
3577
DXILOpBuilder DXILB(M, B);
@@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
3981
if (!CI)
4082
continue;
4183

84+
SmallVector<Value *> Args;
85+
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
86+
Args.emplace_back(DXILOpArg);
4287
B.SetInsertPoint(CI);
43-
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
44-
OverloadTy, CI->args());
88+
if (isVectorArgExpansion(F)) {
89+
SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
90+
Args.append(NewArgs.begin(), NewArgs.end());
91+
} else
92+
Args.append(CI->arg_begin(), CI->arg_end());
93+
94+
CallInst *DXILCI =
95+
DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
4596

4697
CI->replaceAllUsesWith(DXILCI);
4798
CI->eraseFromParent();
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; DXIL operation dot2 does not support double overload type
4+
; CHECK: LLVM ERROR: Invalid Overload
5+
6+
define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
7+
entry:
8+
%dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
9+
ret double %dx.dot
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; DXIL operation dot3 does not support double overload type
4+
; CHECK: LLVM ERROR: Invalid Overload
5+
6+
define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
7+
entry:
8+
%dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
9+
ret double %dx.dot
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
2+
3+
; DXIL operation dot4 does not support double overload type
4+
; CHECK: LLVM ERROR: Invalid Overload
5+
6+
define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
7+
entry:
8+
%dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
9+
ret double %dx.dot
10+
}

llvm/test/CodeGen/DirectX/fdot.ll

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
2+
3+
; Make sure dxil operation function calls for dot are generated for int/uint vectors.
4+
5+
; CHECK-LABEL: dot_half2
6+
define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
7+
entry:
8+
; CHECK: extractelement <2 x half> %a, i32 0
9+
; CHECK: extractelement <2 x half> %a, i32 1
10+
; CHECK: extractelement <2 x half> %b, i32 0
11+
; CHECK: extractelement <2 x half> %b, i32 1
12+
; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
13+
%dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
14+
ret half %dx.dot
15+
}
16+
17+
; CHECK-LABEL: dot_half3
18+
define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
19+
entry:
20+
; CHECK: extractelement <3 x half> %a, i32 0
21+
; CHECK: extractelement <3 x half> %a, i32 1
22+
; CHECK: extractelement <3 x half> %a, i32 2
23+
; CHECK: extractelement <3 x half> %b, i32 0
24+
; CHECK: extractelement <3 x half> %b, i32 1
25+
; CHECK: extractelement <3 x half> %b, i32 2
26+
; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
27+
%dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
28+
ret half %dx.dot
29+
}
30+
31+
; CHECK-LABEL: dot_half4
32+
define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
33+
entry:
34+
; CHECK: extractelement <4 x half> %a, i32 0
35+
; CHECK: extractelement <4 x half> %a, i32 1
36+
; CHECK: extractelement <4 x half> %a, i32 2
37+
; CHECK: extractelement <4 x half> %a, i32 3
38+
; CHECK: extractelement <4 x half> %b, i32 0
39+
; CHECK: extractelement <4 x half> %b, i32 1
40+
; CHECK: extractelement <4 x half> %b, i32 2
41+
; CHECK: extractelement <4 x half> %b, i32 3
42+
; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
43+
%dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
44+
ret half %dx.dot
45+
}
46+
47+
; CHECK-LABEL: dot_float2
48+
define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
49+
entry:
50+
; CHECK: extractelement <2 x float> %a, i32 0
51+
; CHECK: extractelement <2 x float> %a, i32 1
52+
; CHECK: extractelement <2 x float> %b, i32 0
53+
; CHECK: extractelement <2 x float> %b, i32 1
54+
; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
55+
%dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
56+
ret float %dx.dot
57+
}
58+
59+
; CHECK-LABEL: dot_float3
60+
define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
61+
entry:
62+
; CHECK: extractelement <3 x float> %a, i32 0
63+
; CHECK: extractelement <3 x float> %a, i32 1
64+
; CHECK: extractelement <3 x float> %a, i32 2
65+
; CHECK: extractelement <3 x float> %b, i32 0
66+
; CHECK: extractelement <3 x float> %b, i32 1
67+
; CHECK: extractelement <3 x float> %b, i32 2
68+
; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
69+
%dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
70+
ret float %dx.dot
71+
}
72+
73+
; CHECK-LABEL: dot_float4
74+
define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
75+
entry:
76+
; CHECK: extractelement <4 x float> %a, i32 0
77+
; CHECK: extractelement <4 x float> %a, i32 1
78+
; CHECK: extractelement <4 x float> %a, i32 2
79+
; CHECK: extractelement <4 x float> %a, i32 3
80+
; CHECK: extractelement <4 x float> %b, i32 0
81+
; CHECK: extractelement <4 x float> %b, i32 1
82+
; CHECK: extractelement <4 x float> %b, i32 2
83+
; CHECK: extractelement <4 x float> %b, i32 3
84+
; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
85+
%dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
86+
ret float %dx.dot
87+
}
88+
89+
declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> )
90+
declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> )
91+
declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> )
92+
declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>)
93+
declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>)
94+
declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>)

0 commit comments

Comments
 (0)