Skip to content

Commit e3ca0f0

Browse files
committed
optimize expansion, update tests and add scalar test variants
1 parent 547b4da commit e3ca0f0

File tree

3 files changed

+89
-66
lines changed

3 files changed

+89
-66
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,6 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
237237
IRBuilder<> Builder(Orig->getParent());
238238
Builder.SetInsertPoint(Orig);
239239

240-
Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
241240
auto *XVec = dyn_cast<FixedVectorType>(Ty);
242241
if (!XVec) {
243242
if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
@@ -253,25 +252,47 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
253252
return true;
254253
}
255254

255+
Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
256256
unsigned XVecSize = XVec->getNumElements();
257-
Value *Sum = Builder.CreateFMul(Elt, Elt);
258-
for (unsigned I = 1; I < XVecSize; I++) {
259-
Elt = Builder.CreateExtractElement(X, I);
260-
Value *Mul = Builder.CreateFMul(Elt, Elt);
261-
Sum = Builder.CreateFAdd(Sum, Mul);
257+
Value *DotProduct = nullptr;
258+
switch (XVecSize) {
259+
case 1:
260+
report_fatal_error(Twine("Invalid input vector: length is zero"),
261+
/* gen_crash_diag=*/false);
262+
break;
263+
case 2:
264+
DotProduct = Builder.CreateIntrinsic(
265+
EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
266+
break;
267+
case 3:
268+
DotProduct = Builder.CreateIntrinsic(
269+
EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
270+
break;
271+
case 4:
272+
DotProduct = Builder.CreateIntrinsic(
273+
EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
274+
break;
275+
default:
276+
report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
277+
/* gen_crash_diag=*/false);
262278
}
263-
Value *Length = Builder.CreateIntrinsic(
264-
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
279+
280+
Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
281+
ArrayRef<Value *>{DotProduct},
282+
nullptr, "dx.rsqrt");
265283

266284
// verify that the length is non-zero
267-
if (auto *constantFP = dyn_cast<ConstantFP>(Length)) {
285+
// (if the reciprocal sqrt of the length is non-zero, then the length is
286+
// non-zero)
287+
if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
268288
const APFloat &fpVal = constantFP->getValueAPF();
269289
if (fpVal.isZero())
270290
report_fatal_error(Twine("Invalid input vector: length is zero"),
271291
/* gen_crash_diag=*/false);
272292
}
273-
Value *LengthVec = Builder.CreateVectorSplat(XVecSize, Length);
274-
Value *Result = Builder.CreateFDiv(X, LengthVec);
293+
294+
Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
295+
Value *Result = Builder.CreateFMul(X, MultiplicandVec);
275296

276297
Orig->replaceAllUsesWith(Result);
277298
Orig->eraseFromParent();

llvm/test/CodeGen/DirectX/normalize.ll

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,23 @@ declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
1313
declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
1414
declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
1515

16+
define noundef half @test_normalize_half(half noundef %p0) {
17+
entry:
18+
; CHECK: fdiv half %p0, %p0
19+
%hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
20+
ret half %hlsl.normalize
21+
}
22+
1623
define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
1724
entry:
1825
; CHECK: extractelement <2 x half> %{{.*}}, i64 0
19-
; CHECK: fmul half %{{.*}}, %{{.*}}
20-
; CHECK: extractelement <2 x half> %{{.*}}, i64 1
21-
; CHECK: fmul half %{{.*}}, %{{.*}}
22-
; CHECK: fadd half %{{.*}}, %{{.*}}
23-
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
24-
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
26+
; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
27+
; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
28+
; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
29+
; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
30+
; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
31+
; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
32+
; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}
2533

2634
%hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
2735
ret <2 x half> %hlsl.normalize
@@ -30,15 +38,13 @@ entry:
3038
define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
3139
entry:
3240
; CHECK: extractelement <3 x half> %{{.*}}, i64 0
33-
; CHECK: fmul half %{{.*}}, %{{.*}}
34-
; CHECK: extractelement <3 x half> %{{.*}}, i64 1
35-
; CHECK: fmul half %{{.*}}, %{{.*}}
36-
; CHECK: fadd half %{{.*}}, %{{.*}}
37-
; CHECK: extractelement <3 x half> %{{.*}}, i64 2
38-
; CHECK: fmul half %{{.*}}, %{{.*}}
39-
; CHECK: fadd half %{{.*}}, %{{.*}}
40-
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
41-
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
41+
; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
42+
; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
43+
; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
44+
; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
45+
; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
46+
; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
47+
; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
4248

4349
%hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
4450
ret <3 x half> %hlsl.normalize
@@ -47,32 +53,35 @@ entry:
4753
define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
4854
entry:
4955
; CHECK: extractelement <4 x half> %{{.*}}, i64 0
50-
; CHECK: fmul half %{{.*}}, %{{.*}}
51-
; CHECK: extractelement <4 x half> %{{.*}}, i64 1
52-
; CHECK: fmul half %{{.*}}, %{{.*}}
53-
; CHECK: fadd half %{{.*}}, %{{.*}}
54-
; CHECK: extractelement <4 x half> %{{.*}}, i64 2
55-
; CHECK: fmul half %{{.*}}, %{{.*}}
56-
; CHECK: fadd half %{{.*}}, %{{.*}}
57-
; CHECK: extractelement <4 x half> %{{.*}}, i64 3
58-
; CHECK: fmul half %{{.*}}, %{{.*}}
59-
; CHECK: fadd half %{{.*}}, %{{.*}}
60-
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
61-
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})
56+
; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
57+
; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
58+
; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
59+
; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
60+
; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
61+
; CHECK: shufflevector <4 x half> %{{.*}}, <4 x half> poison, <4 x i32> zeroinitializer
62+
; CHECK: fmul <4 x half> %{{.*}}, %{{.*}}
6263

6364
%hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16(<4 x half> %p0)
6465
ret <4 x half> %hlsl.normalize
6566
}
6667

68+
define noundef float @test_normalize_float(float noundef %p0) {
69+
entry:
70+
; CHECK: fdiv float %p0, %p0
71+
%hlsl.normalize = call float @llvm.dx.normalize.f32(float %p0)
72+
ret float %hlsl.normalize
73+
}
74+
6775
define noundef <2 x float> @test_normalize_float2(<2 x float> noundef %p0) {
6876
entry:
6977
; CHECK: extractelement <2 x float> %{{.*}}, i64 0
70-
; CHECK: fmul float %{{.*}}, %{{.*}}
71-
; CHECK: extractelement <2 x float> %{{.*}}, i64 1
72-
; CHECK: fmul float %{{.*}}, %{{.*}}
73-
; CHECK: fadd float %{{.*}}, %{{.*}}
74-
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
75-
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
78+
; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
79+
; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
80+
; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
81+
; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
82+
; CHECK: insertelement <2 x float> poison, float %{{.*}}, i64 0
83+
; CHECK: shufflevector <2 x float> %{{.*}}, <2 x float> poison, <2 x i32> zeroinitializer
84+
; CHECK: fmul <2 x float> %{{.*}}, %{{.*}}
7685

7786
%hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(<2 x float> %p0)
7887
ret <2 x float> %hlsl.normalize
@@ -81,15 +90,13 @@ entry:
8190
define noundef <3 x float> @test_normalize_float3(<3 x float> noundef %p0) {
8291
entry:
8392
; CHECK: extractelement <3 x float> %{{.*}}, i64 0
84-
; CHECK: fmul float %{{.*}}, %{{.*}}
85-
; CHECK: extractelement <3 x float> %{{.*}}, i64 1
86-
; CHECK: fmul float %{{.*}}, %{{.*}}
87-
; CHECK: fadd float %{{.*}}, %{{.*}}
88-
; CHECK: extractelement <3 x float> %{{.*}}, i64 2
89-
; CHECK: fmul float %{{.*}}, %{{.*}}
90-
; CHECK: fadd float %{{.*}}, %{{.*}}
91-
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
92-
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
93+
; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
94+
; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
95+
; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
96+
; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
97+
; CHECK: insertelement <3 x float> poison, float %{{.*}}, i64 0
98+
; CHECK: shufflevector <3 x float> %{{.*}}, <3 x float> poison, <3 x i32> zeroinitializer
99+
; CHECK: fmul <3 x float> %{{.*}}, %{{.*}}
93100

94101
%hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(<3 x float> %p0)
95102
ret <3 x float> %hlsl.normalize
@@ -98,18 +105,13 @@ entry:
98105
define noundef <4 x float> @test_normalize_float4(<4 x float> noundef %p0) {
99106
entry:
100107
; CHECK: extractelement <4 x float> %{{.*}}, i64 0
101-
; CHECK: fmul float %{{.*}}, %{{.*}}
102-
; CHECK: extractelement <4 x float> %{{.*}}, i64 1
103-
; CHECK: fmul float %{{.*}}, %{{.*}}
104-
; CHECK: fadd float %{{.*}}, %{{.*}}
105-
; CHECK: extractelement <4 x float> %{{.*}}, i64 2
106-
; CHECK: fmul float %{{.*}}, %{{.*}}
107-
; CHECK: fadd float %{{.*}}, %{{.*}}
108-
; CHECK: extractelement <4 x float> %{{.*}}, i64 3
109-
; CHECK: fmul float %{{.*}}, %{{.*}}
110-
; CHECK: fadd float %{{.*}}, %{{.*}}
111-
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
112-
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})
108+
; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
109+
; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
110+
; EXPCHECK: call float @llvm.dx.rsqrt.f32(float %{{.*}})
111+
; DOPCHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
112+
; CHECK: insertelement <4 x float> poison, float %{{.*}}, i64 0
113+
; CHECK: shufflevector <4 x float> %{{.*}}, <4 x float> poison, <4 x i32> zeroinitializer
114+
; CHECK: fmul <4 x float> %{{.*}}, %{{.*}}
113115

114116
%hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(<4 x float> %p0)
115117
ret <4 x float> %hlsl.normalize

llvm/test/CodeGen/DirectX/normalize_error.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
22

33
; DXIL operation normalize does not support double overload type
4-
; CHECK: Cannot create Sqrt operation: Invalid overload type
4+
; CHECK: Cannot create Dot2 operation: Invalid overload type
55

66
define noundef <2 x double> @test_normalize_double2(<2 x double> noundef %p0) {
77
entry:

0 commit comments

Comments
 (0)