Skip to content

Commit ff5816a

Browse files
authored
[DirectX] Add all lowering (#105787)
- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all` - DirectX\all.ll: Add test case completes #88946
1 parent 4bab038 commit ff5816a

File tree

2 files changed

+110
-24
lines changed

2 files changed

+110
-24
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
3838
case Intrinsic::log:
3939
case Intrinsic::log10:
4040
case Intrinsic::pow:
41+
case Intrinsic::dx_all:
4142
case Intrinsic::dx_any:
4243
case Intrinsic::dx_clamp:
4344
case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
5455

5556
static Value *expandAbs(CallInst *Orig) {
5657
Value *X = Orig->getOperand(0);
57-
IRBuilder<> Builder(Orig->getParent());
58-
Builder.SetInsertPoint(Orig);
58+
IRBuilder<> Builder(Orig);
5959
Type *Ty = X->getType();
6060
Type *EltTy = Ty->getScalarType();
6161
Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
148148

149149
static Value *expandExpIntrinsic(CallInst *Orig) {
150150
Value *X = Orig->getOperand(0);
151-
IRBuilder<> Builder(Orig->getParent());
152-
Builder.SetInsertPoint(Orig);
151+
IRBuilder<> Builder(Orig);
153152
Type *Ty = X->getType();
154153
Type *EltTy = Ty->getScalarType();
155154
Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
166165
return Exp2Call;
167166
}
168167

169-
static Value *expandAnyIntrinsic(CallInst *Orig) {
168+
static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
169+
Intrinsic::ID intrinsicId) {
170170
Value *X = Orig->getOperand(0);
171-
IRBuilder<> Builder(Orig->getParent());
172-
Builder.SetInsertPoint(Orig);
171+
IRBuilder<> Builder(Orig);
173172
Type *Ty = X->getType();
174173
Type *EltTy = Ty->getScalarType();
175174

175+
auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
176+
Value *Elt) {
177+
if (IntrinsicId == Intrinsic::dx_any)
178+
return Builder.CreateOr(Result, Elt);
179+
assert(IntrinsicId == Intrinsic::dx_all);
180+
return Builder.CreateAnd(Result, Elt);
181+
};
182+
176183
Value *Result = nullptr;
177184
if (!Ty->isVectorTy()) {
178185
Result = EltTy->isFloatingPointTy()
@@ -193,16 +200,15 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
193200
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
194201
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
195202
Value *Elt = Builder.CreateExtractElement(Cond, I);
196-
Result = Builder.CreateOr(Result, Elt);
203+
Result = ApplyOp(intrinsicId, Result, Elt);
197204
}
198205
}
199206
return Result;
200207
}
201208

202209
static Value *expandLengthIntrinsic(CallInst *Orig) {
203210
Value *X = Orig->getOperand(0);
204-
IRBuilder<> Builder(Orig->getParent());
205-
Builder.SetInsertPoint(Orig);
211+
IRBuilder<> Builder(Orig);
206212
Type *Ty = X->getType();
207213
Type *EltTy = Ty->getScalarType();
208214

@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
230236
Value *X = Orig->getOperand(0);
231237
Value *Y = Orig->getOperand(1);
232238
Value *S = Orig->getOperand(2);
233-
IRBuilder<> Builder(Orig->getParent());
234-
Builder.SetInsertPoint(Orig);
239+
IRBuilder<> Builder(Orig);
235240
auto *V = Builder.CreateFSub(Y, X);
236241
V = Builder.CreateFMul(S, V);
237242
return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
240245
static Value *expandLogIntrinsic(CallInst *Orig,
241246
float LogConstVal = numbers::ln2f) {
242247
Value *X = Orig->getOperand(0);
243-
IRBuilder<> Builder(Orig->getParent());
244-
Builder.SetInsertPoint(Orig);
248+
IRBuilder<> Builder(Orig);
245249
Type *Ty = X->getType();
246250
Type *EltTy = Ty->getScalarType();
247251
Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
266270
Value *X = Orig->getOperand(0);
267271
Type *Ty = Orig->getType();
268272
Type *EltTy = Ty->getScalarType();
269-
IRBuilder<> Builder(Orig->getParent());
270-
Builder.SetInsertPoint(Orig);
273+
IRBuilder<> Builder(Orig);
271274

272275
auto *XVec = dyn_cast<FixedVectorType>(Ty);
273276
if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
305308
Value *X = Orig->getOperand(0);
306309
Value *Y = Orig->getOperand(1);
307310
Type *Ty = X->getType();
308-
IRBuilder<> Builder(Orig->getParent());
309-
Builder.SetInsertPoint(Orig);
311+
IRBuilder<> Builder(Orig);
310312

311313
auto *Log2Call =
312314
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
350352
Value *Min = Orig->getOperand(1);
351353
Value *Max = Orig->getOperand(2);
352354
Type *Ty = X->getType();
353-
IRBuilder<> Builder(Orig->getParent());
354-
Builder.SetInsertPoint(Orig);
355+
IRBuilder<> Builder(Orig);
355356
auto *MaxCall = Builder.CreateIntrinsic(
356357
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
357358
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
360361

361362
static bool expandIntrinsic(Function &F, CallInst *Orig) {
362363
Value *Result = nullptr;
363-
switch (F.getIntrinsicID()) {
364+
Intrinsic::ID IntrinsicId = F.getIntrinsicID();
365+
switch (IntrinsicId) {
364366
case Intrinsic::abs:
365367
Result = expandAbs(Orig);
366368
break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
376378
case Intrinsic::pow:
377379
Result = expandPowIntrinsic(Orig);
378380
break;
381+
case Intrinsic::dx_all:
379382
case Intrinsic::dx_any:
380-
Result = expandAnyIntrinsic(Orig);
383+
Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
381384
break;
382385
case Intrinsic::dx_uclamp:
383386
case Intrinsic::dx_clamp:
384-
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
387+
Result = expandClampIntrinsic(Orig, IntrinsicId);
385388
break;
386389
case Intrinsic::dx_lerp:
387390
Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
397400
break;
398401
case Intrinsic::dx_sdot:
399402
case Intrinsic::dx_udot:
400-
Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
403+
Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
401404
break;
402405
}
403406

llvm/test/CodeGen/DirectX/all.ll

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
2+
3+
; Make sure dxil operation function calls for all are generated for float and half.
4+
5+
; CHECK-LABEL: all_bool
6+
; CHECK: icmp ne i1 %{{.*}}, false
7+
define noundef i1 @all_bool(i1 noundef %p0) {
8+
entry:
9+
%dx.all = call i1 @llvm.dx.all.i1(i1 %p0)
10+
ret i1 %dx.all
11+
}
12+
13+
; CHECK-LABEL: all_int64_t
14+
; CHECK: icmp ne i64 %{{.*}}, 0
15+
define noundef i1 @all_int64_t(i64 noundef %p0) {
16+
entry:
17+
%dx.all = call i1 @llvm.dx.all.i64(i64 %p0)
18+
ret i1 %dx.all
19+
}
20+
21+
; CHECK-LABEL: all_int
22+
; CHECK: icmp ne i32 %{{.*}}, 0
23+
define noundef i1 @all_int(i32 noundef %p0) {
24+
entry:
25+
%dx.all = call i1 @llvm.dx.all.i32(i32 %p0)
26+
ret i1 %dx.all
27+
}
28+
29+
; CHECK-LABEL: all_int16_t
30+
; CHECK: icmp ne i16 %{{.*}}, 0
31+
define noundef i1 @all_int16_t(i16 noundef %p0) {
32+
entry:
33+
%dx.all = call i1 @llvm.dx.all.i16(i16 %p0)
34+
ret i1 %dx.all
35+
}
36+
37+
; CHECK-LABEL: all_double
38+
; CHECK: fcmp une double %{{.*}}, 0.000000e+00
39+
define noundef i1 @all_double(double noundef %p0) {
40+
entry:
41+
%dx.all = call i1 @llvm.dx.all.f64(double %p0)
42+
ret i1 %dx.all
43+
}
44+
45+
; CHECK-LABEL: all_float
46+
; CHECK: fcmp une float %{{.*}}, 0.000000e+00
47+
define noundef i1 @all_float(float noundef %p0) {
48+
entry:
49+
%dx.all = call i1 @llvm.dx.all.f32(float %p0)
50+
ret i1 %dx.all
51+
}
52+
53+
; CHECK-LABEL: all_half
54+
; CHECK: fcmp une half %{{.*}}, 0xH0000
55+
define noundef i1 @all_half(half noundef %p0) {
56+
entry:
57+
%dx.all = call i1 @llvm.dx.all.f16(half %p0)
58+
ret i1 %dx.all
59+
}
60+
61+
; CHECK-LABEL: all_bool4
62+
; CHECK: icmp ne <4 x i1> %{{.*}}, zeroinitialize
63+
; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
64+
; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
65+
; CHECK: and i1 %{{.*}}, %{{.*}}
66+
; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
67+
; CHECK: and i1 %{{.*}}, %{{.*}}
68+
; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
69+
; CHECK: and i1 %{{.*}}, %{{.*}}
70+
define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
71+
entry:
72+
%dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %p0)
73+
ret i1 %dx.all
74+
}
75+
76+
declare i1 @llvm.dx.all.v4i1(<4 x i1>)
77+
declare i1 @llvm.dx.all.i1(i1)
78+
declare i1 @llvm.dx.all.i16(i16)
79+
declare i1 @llvm.dx.all.i32(i32)
80+
declare i1 @llvm.dx.all.i64(i64)
81+
declare i1 @llvm.dx.all.f16(half)
82+
declare i1 @llvm.dx.all.f32(float)
83+
declare i1 @llvm.dx.all.f64(double)

0 commit comments

Comments
 (0)