Skip to content

[DirectX] Use scalar arguments for @llvm.dx.dot intrinsics #134570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions clang/lib/CodeGen/TargetBuiltins/DirectX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ Value *CodeGenFunction::EmitDirectXBuiltinExpr(unsigned BuiltinID,
case DirectX::BI__builtin_dx_dot2add: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));
Value *Acc = EmitScalarExpr(E->getArg(2));

Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));

Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"dx.dot2add");
/*ReturnType=*/Acc->getType(), ID,
ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
}
}
return nullptr;
Expand Down
10 changes: 8 additions & 2 deletions clang/test/CodeGenDirectX/Builtins/dot2add.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ typedef half half2 __attribute__((ext_vector_type(2)));
// CHECK-NEXT: [[TMP0:%.*]] = load <2 x half>, ptr [[X_ADDR]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[Y_ADDR]], align 4
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[Z_ADDR]], align 4
// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add.v2f16(<2 x half> [[TMP0]], <2 x half> [[TMP1]], float [[TMP2]])
// CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x half> [[TMP0]], i32 0
// CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x half> [[TMP0]], i32 1
// CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x half> [[TMP1]], i32 0
// CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x half> [[TMP1]], i32 1
// CHECK-NEXT: [[DX_DOT2ADD:%.*]] = call float @llvm.dx.dot2add(float [[TMP2]], half [[TMP3]], half [[TMP4]], half [[TMP5]], half [[TMP6]])
// CHECK-NEXT: ret float [[DX_DOT2ADD]]
//
float test_dot2add(half2 X, half2 Y, float Z) { return __builtin_dx_dot2add(X, Y, Z); }
float test_dot2add(half2 X, half2 Y, float Z) {
return __builtin_dx_dot2add(X, Y, Z);
}
60 changes: 50 additions & 10 deletions clang/test/CodeGenHLSL/builtins/dot2add.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
Expand All @@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
39 changes: 24 additions & 15 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

def int_dx_dot2 :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot3 :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot4 :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[
llvm_anyfloat_ty, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>
],
[IntrNoMem, Commutative]>;
def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[
llvm_anyfloat_ty, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>
],
[IntrNoMem, Commutative]>;
def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[
llvm_anyfloat_ty, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>
],
[IntrNoMem, Commutative]>;
def int_dx_fdot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
Expand All @@ -100,9 +109,9 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot2add :
DefaultAttrsIntrinsic<[llvm_float_ty],
[llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
def int_dx_dot2add :
DefaultAttrsIntrinsic<[llvm_float_ty],
[llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
[IntrNoMem, Commutative]>;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
}

def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
let Doc = "dot product of 2 vectors of half having size = 2, returns "
"float";
let Doc = "2D half dot product with accumulate to float";
let intrinsics = [IntrinSelect<int_dx_dot2add>];
let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
let result = FloatTy;
Expand Down
13 changes: 10 additions & 3 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
assert(ATy->getScalarType()->isFloatingPointTy());

Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
switch (AVec->getNumElements()) {
int NumElts = AVec->getNumElements();
switch (NumElts) {
case 2:
DotIntrinsic = Intrinsic::dx_dot2;
break;
Expand All @@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
/* gen_crash_diag=*/false);
return nullptr;
}
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
ArrayRef<Value *>{A, B}, nullptr, "dot");

SmallVector<Value *> Args;
for (int I = 0; I < NumElts; ++I)
Copy link
Member

@farzonl farzonl Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have the new DirectX target builtins I had another thought on how to do this that would eliminate our convert one intrinsic to another intrinsic trick we are doing here. Just expose dot2, dot3, and dot4 in the frontend. If we are comfortable with doing scalarization in the frontend for dot2add then does it make sense to do that for these fdot cases?

Downside is we would need sema check for these 3 new cases which seems like overkill and maybe disqualifies this idea.

You may also have to call a builtin from a builtin via FunctionDecl *lookupBuiltinFunction(..) to avoid having to change the HLSL headers which is maybe another disqualifier.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think having a generic fdot for the frontend to emit makes sense. The implementation being per vector size is a bit odd and wouldn't scale, so if we did ever have a >4 element fdot I feel like it would be nicer for the backend to handle it rather than need the frontend to contend with all of these details.

Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
for (int I = 0; I < NumElts; ++I)
Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
nullptr, "dot");
}

// Create the appropriate DXIL float dot intrinsic for the operands of Orig
Expand Down
58 changes: 0 additions & 58 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,52 +33,6 @@
using namespace llvm;
using namespace llvm::dxil;

static bool isVectorArgExpansion(Function &F) {
switch (F.getIntrinsicID()) {
case Intrinsic::dx_dot2:
case Intrinsic::dx_dot3:
case Intrinsic::dx_dot4:
return true;
}
return false;
}

static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
SmallVector<Value *> ExtractedElements;
auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
ExtractedElements.push_back(ExtractedElement);
}
return ExtractedElements;
}

static SmallVector<Value *>
argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
assert(NumOperands > 0);
Value *Arg0 = Orig->getOperand(0);
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
assert(VecArg0);
SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
for (unsigned I = 1; I < NumOperands; ++I) {
Value *Arg = Orig->getOperand(I);
[[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
assert(VecArg);
assert(VecArg0->getElementType() == VecArg->getElementType());
assert(VecArg0->getNumElements() == VecArg->getNumElements());
auto NextOperandList = populateOperands(Arg, Builder);
NewOperands.append(NextOperandList.begin(), NextOperandList.end());
}
return NewOperands;
}

static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
IRBuilder<> &Builder) {
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
}

namespace {
class OpLowerer {
Module &M;
Expand Down Expand Up @@ -150,9 +104,6 @@ class OpLowerer {
[[nodiscard]] bool
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
ArrayRef<IntrinArgSelect> ArgSelects) {
bool IsVectorArgExpansion = isVectorArgExpansion(F);
assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
"Cann't do vector arg expansion when using arg selects.");
return replaceFunction(F, [&](CallInst *CI) -> Error {
OpBuilder.getIRB().SetInsertPoint(CI);
SmallVector<Value *> Args;
Expand All @@ -170,15 +121,6 @@ class OpLowerer {
break;
}
}
} else if (IsVectorArgExpansion) {
Args = argVectorFlatten(CI, OpBuilder.getIRB());
} else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
// arg[NumOperands-1] is a pointer and is not needed by our flattening.
// arg[NumOperands-2] also does not need to be flattened because it is a
// scalar.
unsigned NumOperands = CI->getNumOperands() - 2;
Args.push_back(CI->getArgOperand(NumOperands));
Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
} else {
Args.append(CI->arg_begin(), CI->arg_end());
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/test/CodeGen/DirectX/dot2_error.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
; CHECK: in function dot_double2
; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type

define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
define noundef double @dot_double2(double noundef %a1, double noundef %a2,
double noundef %b1, double noundef %b2) {
entry:
%dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
%dx.dot = call double @llvm.dx.dot2(double %a1, double %a2, double %b1, double %b2)
ret double %dx.dot
}
Loading