Skip to content

[DXIL] Add Float Dot Intrinsic Lowering #86071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18066,15 +18066,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
return Arg;
}

Intrinsic::ID getDotProductIntrinsic(QualType QT) {
Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) {
if (QT->hasFloatingRepresentation()) {
switch (elementCount) {
case 2:
return Intrinsic::dx_dot2;
case 3:
return Intrinsic::dx_dot3;
case 4:
return Intrinsic::dx_dot4;
}
}
if (QT->hasSignedIntegerRepresentation())
return Intrinsic::dx_sdot;
if (QT->hasUnsignedIntegerRepresentation())
return Intrinsic::dx_udot;

assert(QT->hasFloatingRepresentation());
return Intrinsic::dx_dot;
;
assert(QT->hasUnsignedIntegerRepresentation());
return Intrinsic::dx_udot;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Expand Down Expand Up @@ -18128,8 +18135,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(T0->getScalarType() == T1->getScalarType() &&
"Dot product of vectors need the same element types.");

[[maybe_unused]] auto *VecTy0 =
E->getArg(0)->getType()->getAs<VectorType>();
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
[[maybe_unused]] auto *VecTy1 =
E->getArg(1)->getType()->getAs<VectorType>();
// A HLSLVectorTruncation should have happened
Expand All @@ -18138,7 +18144,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,

return Builder.CreateIntrinsic(
/*ReturnType=*/T0->getScalarType(),
getDotProductIntrinsic(E->getArg(0)->getType()),
getDotProductIntrinsic(E->getArg(0)->getType(),
VecTy0->getNumElements()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
} break;
case Builtin::BI__builtin_hlsl_lerp: {
Expand Down
28 changes: 14 additions & 14 deletions clang/test/CodeGenHLSL/builtins/dot.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -110,56 +110,56 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { return dot(p0, p1); }
// NO_HALF: ret float %dx.dot
half test_dot_half(half p0, half p1) { return dot(p0, p1); }

// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1)
// NATIVE_HALF: ret half %dx.dot
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); }

// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1)
// NATIVE_HALF: ret half %dx.dot
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); }

// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1)
// NATIVE_HALF: ret half %dx.dot
// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// NO_HALF: ret float %dx.dot
half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = fmul float %0, %1
// CHECK: ret float %dx.dot
float test_dot_float(float p0, float p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }

// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1)
// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
// CHECK: ret float %dx.dot
float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }

// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer
// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat)
// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
return dot(p0, p1);
Expand All @@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) {
// CHECK: %conv = sitofp i32 %1 to float
// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0
// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer
// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat)
// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat)
// CHECK: ret float %dx.dot
float test_builtin_dot_float3_int_splat(float3 p0, int p1) {
return dot(p0, p1);
Expand Down
10 changes: 9 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;

def int_dx_dot :
def int_dx_dot2 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_dot3 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_dot4 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
"Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
"Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
"dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
"Reads the thread ID">;
def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
Expand Down
8 changes: 3 additions & 5 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ namespace dxil {

CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
llvm::iterator_range<Use *> Args) {
SmallVector<Value *> Args) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);

OverloadKind Kind = getOverloadKind(OverloadTy);
Expand All @@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
}
SmallVector<Value *> FullArgs;
FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
FullArgs.append(Args.begin(), Args.end());
return B.CreateCall(DXILFn, FullArgs);

return B.CreateCall(DXILFn, Args);
}

Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/DirectX/DXILOpBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H

#include "DXILConstants.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/ADT/SmallVector.h"

namespace llvm {
class Module;
Expand All @@ -35,8 +35,7 @@ class DXILOpBuilder {
/// \param OverloadTy Overload type of the DXIL Op call constructed
/// \return DXIL Op call constructed
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
Type *OverloadTy,
llvm::iterator_range<Use *> Args);
Type *OverloadTy, SmallVector<Value *> Args);
Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);

Expand Down
55 changes: 53 additions & 2 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,48 @@
using namespace llvm;
using namespace llvm::dxil;

/// Returns true when \p F is one of the DirectX dot-product intrinsics
/// (dx_dot2, dx_dot3 or dx_dot4), i.e. an intrinsic whose vector arguments
/// are flattened into scalars when the call is lowered.
static bool isVectorArgExpansion(Function &F) {
  const Intrinsic::ID IID = F.getIntrinsicID();
  return IID == Intrinsic::dx_dot2 || IID == Intrinsic::dx_dot3 ||
         IID == Intrinsic::dx_dot4;
}

/// Scalarizes the fixed-width vector value \p Arg.
///
/// Emits one extractelement instruction per lane through \p Builder, using
/// i32 lane indices, and returns the extracted scalars in lane order.
///
/// \param Arg     Value whose type must be a FixedVectorType.
/// \param Builder IR builder used to create the extractelement instructions.
/// \return The extracted elements, element 0 first.
static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
  // cast<> rather than dyn_cast<>: the result is dereferenced unconditionally,
  // so a non-vector argument must trip the cast assertion (in assert-enabled
  // builds) instead of silently dereferencing a null pointer.
  auto *VecTy = cast<FixedVectorType>(Arg->getType());
  const unsigned NumElements = VecTy->getNumElements();

  // Same type as the return type so the result can be returned without a
  // converting copy.
  SmallVector<Value *> ExtractedElements;
  ExtractedElements.reserve(NumElements);
  for (unsigned I = 0; I < NumElements; ++I) {
    Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
    ExtractedElements.push_back(Builder.CreateExtractElement(Arg, Index));
  }
  return ExtractedElements;
}

/// Flattens every vector argument of the call \p Orig into scalars.
///
/// Each argument is expanded with populateOperands, and the resulting scalars
/// are concatenated in argument order: all elements of argument 0, then all
/// elements of argument 1, and so on.
///
/// All arguments are expected to be fixed-width vectors with the same element
/// type and element count as argument 0; this is checked by assertions only.
///
/// \param Orig    Call whose arguments are flattened; must have at least one.
/// \param Builder IR builder used to emit the extractelement instructions.
/// \return The flattened scalar operands.
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
                                             IRBuilder<> &Builder) {
  // Walk the call arguments only. arg_size() does not count the trailing
  // callee operand that getNumOperands() includes.
  const unsigned NumArgs = Orig->arg_size();
  assert(NumArgs > 0 && "Expected at least one argument to flatten.");

  Value *Arg0 = Orig->getArgOperand(0);
  [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
  assert(VecArg0 && "First argument must be a fixed vector.");
  SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);

  for (unsigned I = 1; I < NumArgs; ++I) {
    Value *Arg = Orig->getArgOperand(I);
    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
    assert(VecArg && "Every argument must be a fixed vector.");
    assert(VecArg0->getElementType() == VecArg->getElementType() &&
           "Vector arguments must share an element type.");
    assert(VecArg0->getNumElements() == VecArg->getNumElements() &&
           "Vector arguments must share an element count.");
    auto NextOperandList = populateOperands(Arg, Builder);
    NewOperands.append(NextOperandList.begin(), NextOperandList.end());
  }
  return NewOperands;
}

static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
DXILOpBuilder DXILB(M, B);
Expand All @@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
if (!CI)
continue;

SmallVector<Value *> Args;
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
Args.emplace_back(DXILOpArg);
B.SetInsertPoint(CI);
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
OverloadTy, CI->args());
if (isVectorArgExpansion(F)) {
SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
Args.append(NewArgs.begin(), NewArgs.end());
} else
Args.append(CI->arg_begin(), CI->arg_end());

CallInst *DXILCI =
DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);

CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/dot2_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; DXIL operation dot2 does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload

define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
entry:
%dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
ret double %dx.dot
}
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/dot3_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; DXIL operation dot3 does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload

define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) {
entry:
%dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b)
ret double %dx.dot
}
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/dot4_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; DXIL operation dot4 does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload

define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) {
entry:
%dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b)
ret double %dx.dot
}
94 changes: 94 additions & 0 deletions llvm/test/CodeGen/DirectX/fdot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s

; Make sure dxil operation function calls for dot are generated for half/float vectors.

; CHECK-LABEL: dot_half2
define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
entry:
; CHECK: extractelement <2 x half> %a, i32 0
; CHECK: extractelement <2 x half> %a, i32 1
; CHECK: extractelement <2 x half> %b, i32 0
; CHECK: extractelement <2 x half> %b, i32 1
; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_half3
define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
entry:
; CHECK: extractelement <3 x half> %a, i32 0
; CHECK: extractelement <3 x half> %a, i32 1
; CHECK: extractelement <3 x half> %a, i32 2
; CHECK: extractelement <3 x half> %b, i32 0
; CHECK: extractelement <3 x half> %b, i32 1
; CHECK: extractelement <3 x half> %b, i32 2
; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_half4
define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
entry:
; CHECK: extractelement <4 x half> %a, i32 0
; CHECK: extractelement <4 x half> %a, i32 1
; CHECK: extractelement <4 x half> %a, i32 2
; CHECK: extractelement <4 x half> %a, i32 3
; CHECK: extractelement <4 x half> %b, i32 0
; CHECK: extractelement <4 x half> %b, i32 1
; CHECK: extractelement <4 x half> %b, i32 2
; CHECK: extractelement <4 x half> %b, i32 3
; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_float2
define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
entry:
; CHECK: extractelement <2 x float> %a, i32 0
; CHECK: extractelement <2 x float> %a, i32 1
; CHECK: extractelement <2 x float> %b, i32 0
; CHECK: extractelement <2 x float> %b, i32 1
; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
ret float %dx.dot
}

; CHECK-LABEL: dot_float3
define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
entry:
; CHECK: extractelement <3 x float> %a, i32 0
; CHECK: extractelement <3 x float> %a, i32 1
; CHECK: extractelement <3 x float> %a, i32 2
; CHECK: extractelement <3 x float> %b, i32 0
; CHECK: extractelement <3 x float> %b, i32 1
; CHECK: extractelement <3 x float> %b, i32 2
; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
ret float %dx.dot
}

; CHECK-LABEL: dot_float4
define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
entry:
; CHECK: extractelement <4 x float> %a, i32 0
; CHECK: extractelement <4 x float> %a, i32 1
; CHECK: extractelement <4 x float> %a, i32 2
; CHECK: extractelement <4 x float> %a, i32 3
; CHECK: extractelement <4 x float> %b, i32 0
; CHECK: extractelement <4 x float> %b, i32 1
; CHECK: extractelement <4 x float> %b, i32 2
; CHECK: extractelement <4 x float> %b, i32 3
; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
ret float %dx.dot
}

; Declarations of the dot2/dot3/dot4 intrinsics called by the tests above.
declare half @llvm.dx.dot2.v2f16(<2 x half> , <2 x half> )
declare half @llvm.dx.dot3.v3f16(<3 x half> , <3 x half> )
declare half @llvm.dx.dot4.v4f16(<4 x half> , <4 x half> )
declare float @llvm.dx.dot2.v2f32(<2 x float>, <2 x float>)
declare float @llvm.dx.dot3.v3f32(<3 x float>, <3 x float>)
declare float @llvm.dx.dot4.v4f32(<4 x float>, <4 x float>)