-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[DirectX] Add `all` lowering
#105787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DirectX] Add `all` lowering
#105787
Conversation
- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all`
- DirectX\all.ll: Add test case
@llvm/pr-subscribers-backend-directx Author: Farzon Lotfi (farzonl) Changes
completes #88946 Full diff: https://github.com/llvm/llvm-project/pull/105787.diff 2 Files Affected:
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::log:
case Intrinsic::log10:
case Intrinsic::pow:
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
return Exp2Call;
}
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+ Intrinsic::ID intrinsicId) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
+ auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+ Value *Elt) {
+ if (IntrinsicId == Intrinsic::dx_any)
+ return Builder.CreateOr(Result, Elt);
+ assert(IntrinsicId == Intrinsic::dx_all);
+ return Builder.CreateAnd(Result, Elt);
+ };
+
Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
- Result = Builder.CreateOr(Result, Elt);
+ Result = ApplyOp(intrinsicId, Result, Elt);
}
}
return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *XVec = dyn_cast<FixedVectorType>(Ty);
if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *Log2Call =
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
- switch (F.getIntrinsicID()) {
+ Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+ switch (IntrinsicId) {
case Intrinsic::abs:
Result = expandAbs(Orig);
break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::pow:
Result = expandPowIntrinsic(Orig);
break;
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
- Result = expandAnyIntrinsic(Orig);
+ Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
- Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandClampIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_lerp:
Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
- Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
break;
}
diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..c82d14f05ee640
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,113 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure calls to the all intrinsic are expanded for bool, integer, double, float, and half scalars and for bool vectors.
+
+; CHECK-LABEL: all_bool
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+ %p0.addr = alloca i8, align 1
+ %frombool = zext i1 %p0 to i8
+ store i8 %frombool, ptr %p0.addr, align 1
+ %0 = load i8, ptr %p0.addr, align 1
+ %tobool = trunc i8 %0 to i1
+ %dx.all = call i1 @llvm.dx.all.i1(i1 %tobool)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int64_t
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+ %p0.addr = alloca i64, align 8
+ store i64 %p0, ptr %p0.addr, align 8
+ %0 = load i64, ptr %p0.addr, align 8
+ %dx.all = call i1 @llvm.dx.all.i64(i64 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+ %p0.addr = alloca i32, align 4
+ store i32 %p0, ptr %p0.addr, align 4
+ %0 = load i32, ptr %p0.addr, align 4
+ %dx.all = call i1 @llvm.dx.all.i32(i32 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int16_t
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+ %p0.addr = alloca i16, align 2
+ store i16 %p0, ptr %p0.addr, align 2
+ %0 = load i16, ptr %p0.addr, align 2
+ %dx.all = call i1 @llvm.dx.all.i16(i16 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_double
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+ %p0.addr = alloca double, align 8
+ store double %p0, ptr %p0.addr, align 8
+ %0 = load double, ptr %p0.addr, align 8
+ %dx.all = call i1 @llvm.dx.all.f64(double %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_float
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+ %p0.addr = alloca float, align 4
+ store float %p0, ptr %p0.addr, align 4
+ %0 = load float, ptr %p0.addr, align 4
+ %dx.all = call i1 @llvm.dx.all.f32(float %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_half
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+ %p0.addr = alloca half, align 2
+ store half %p0, ptr %p0.addr, align 2
+ %0 = load half, ptr %p0.addr, align 2
+ %dx.all = call i1 @llvm.dx.all.f16(half %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_bool4
+; CHECK: icmp ne <4 x i1> %extractvec, zeroinitializer
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: and i1 %{{.*}}, %{{.*}}
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+ %p0.addr = alloca i8, align 1
+ %insertvec = shufflevector <4 x i1> %p0, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+ %0 = bitcast <8 x i1> %insertvec to i8
+ store i8 %0, ptr %p0.addr, align 1
+ %load_bits = load i8, ptr %p0.addr, align 1
+ %1 = bitcast i8 %load_bits to <8 x i1>
+ %extractvec = shufflevector <8 x i1> %1, <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %extractvec)
+ ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, but I think simplifying the tests would be useful.
`any` codegen expansion to work for `all`
completes #88946