Skip to content

[DirectX] Add all lowering #105787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 26, 2024
Merged

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Aug 23, 2024

  • DXILIntrinsicExpansion.cpp: Modify any codegen expansion to work for all
  • DirectX\all.ll: Add test case

completes #88946

- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all`
- DirectX\all.ll: Add test case
@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2024

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

Changes
  • DXILIntrinsicExpansion.cpp: Modify any codegen expansion to work for all
  • DirectX\all.ll: Add test case

completes #88946


Full diff: https://github.com/llvm/llvm-project/pull/105787.diff

2 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+27-24)
  • (added) llvm/test/CodeGen/DirectX/all.ll (+113)
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::log:
   case Intrinsic::log10:
   case Intrinsic::pow:
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
   case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
 
 static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
 
 static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
   return Exp2Call;
 }
 
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+                                      Intrinsic::ID intrinsicId) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+                            Value *Elt) {
+    if (IntrinsicId == Intrinsic::dx_any)
+      return Builder.CreateOr(Result, Elt);
+    assert(IntrinsicId == Intrinsic::dx_all);
+    return Builder.CreateAnd(Result, Elt);
+  };
+
   Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
     Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
-      Result = Builder.CreateOr(Result, Elt);
+      Result = ApplyOp(intrinsicId, Result, Elt);
     }
   }
   return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
 
 static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
   return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
 static Value *expandLogIntrinsic(CallInst *Orig,
                                  float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *XVec = dyn_cast<FixedVectorType>(Ty);
   if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *Log2Call =
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
   return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   Value *Result = nullptr;
-  switch (F.getIntrinsicID()) {
+  Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+  switch (IntrinsicId) {
   case Intrinsic::abs:
     Result = expandAbs(Orig);
     break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::pow:
     Result = expandPowIntrinsic(Orig);
     break;
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
-    Result = expandAnyIntrinsic(Orig);
+    Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_lerp:
     Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
     break;
   }
 
diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..c82d14f05ee640
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,113 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure dxil operation function calls for all are generated for float and half.
+
+; CHECK-LABEL: all_bool
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.all = call i1 @llvm.dx.all.i1(i1 %tobool)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int64_t
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.all = call i1 @llvm.dx.all.i64(i64 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.all = call i1 @llvm.dx.all.i32(i32 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int16_t
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.all = call i1 @llvm.dx.all.i16(i16 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_double
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.all = call i1 @llvm.dx.all.f64(double %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_float
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.all = call i1 @llvm.dx.all.f32(float %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_half
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.all = call i1 @llvm.dx.all.f16(half %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_bool4
+; CHECK: icmp ne <4 x i1> %extractvec, zeroinitialize
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: and i1  %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: and i1  %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: and i1  %{{.*}}, %{{.*}}
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %insertvec = shufflevector <4 x i1> %p0, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+  %0 = bitcast <8 x i1> %insertvec to i8
+  store i8 %0, ptr %p0.addr, align 1
+  %load_bits = load i8, ptr %p0.addr, align 1
+  %1 = bitcast i8 %load_bits to <8 x i1>
+  %extractvec = shufflevector <8 x i1> %1, <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %extractvec)
+  ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)

@farzonl farzonl linked an issue Aug 23, 2024 that may be closed by this pull request
Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

Copy link
Contributor

@dmpots dmpots left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, but I think simplifying the tests would be useful.

@farzonl farzonl merged commit ff5816a into llvm:main Aug 26, 2024
6 of 8 checks passed
@farzonl farzonl deleted the feature/dxil-all-intrinsic branch August 26, 2024 17:40
@farzonl farzonl self-assigned this Nov 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[DXIL][HLSL] add all intrinsic lowering
5 participants