Skip to content

Commit 7dde8fa

Browse files
committed
[DXIL] exp, any, lerp, & rcp Intrinsic Lowering
This change implements lowering for #70076, #70100, #70072, & #70102.
- CGBuiltin.cpp - simplify lerp intrinsic
- IntrinsicsDirectX.td - simplify lerp intrinsic
- SemaChecking.cpp - remove unnecessary check
- DXILIntrinsicExpansion.* - add intrinsic to instruction expansion cases
- DXILOpLowering.cpp - make sure DXILIntrinsicExpansion happens first
- DirectX.h - changes to support new pass
- DirectXTargetMachine.cpp - changes to support new pass
1 parent ee137e2 commit 7dde8fa

File tree

15 files changed

+564
-47
lines changed

15 files changed

+564
-47
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18015,38 +18015,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1801518015
Value *X = EmitScalarExpr(E->getArg(0));
1801618016
Value *Y = EmitScalarExpr(E->getArg(1));
1801718017
Value *S = EmitScalarExpr(E->getArg(2));
18018-
llvm::Type *Xty = X->getType();
18019-
llvm::Type *Yty = Y->getType();
18020-
llvm::Type *Sty = S->getType();
18021-
if (!Xty->isVectorTy() && !Yty->isVectorTy() && !Sty->isVectorTy()) {
18022-
if (Xty->isFloatingPointTy()) {
18023-
auto V = Builder.CreateFSub(Y, X);
18024-
V = Builder.CreateFMul(S, V);
18025-
return Builder.CreateFAdd(X, V, "dx.lerp");
18026-
}
18027-
llvm_unreachable("Scalar Lerp is only supported on floats.");
18028-
}
18029-
// A VectorSplat should have happened
18030-
assert(Xty->isVectorTy() && Yty->isVectorTy() && Sty->isVectorTy() &&
18031-
"Lerp of vector and scalar is not supported.");
18032-
18033-
[[maybe_unused]] auto *XVecTy =
18034-
E->getArg(0)->getType()->getAs<VectorType>();
18035-
[[maybe_unused]] auto *YVecTy =
18036-
E->getArg(1)->getType()->getAs<VectorType>();
18037-
[[maybe_unused]] auto *SVecTy =
18038-
E->getArg(2)->getType()->getAs<VectorType>();
18039-
// A HLSLVectorTruncation should have happend
18040-
assert(XVecTy->getNumElements() == YVecTy->getNumElements() &&
18041-
XVecTy->getNumElements() == SVecTy->getNumElements() &&
18042-
"Lerp requires vectors to be of the same size.");
18043-
assert(XVecTy->getElementType()->isRealFloatingType() &&
18044-
XVecTy->getElementType() == YVecTy->getElementType() &&
18045-
XVecTy->getElementType() == SVecTy->getElementType() &&
18046-
"Lerp requires float vectors to be of the same type.");
18018+
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
18019+
llvm_unreachable("lerp operand must have a float representation");
1804718020
return Builder.CreateIntrinsic(
18048-
/*ReturnType=*/Xty, Intrinsic::dx_lerp, ArrayRef<Value *>{X, Y, S},
18049-
nullptr, "dx.lerp");
18021+
/*ReturnType=*/X->getType(), Intrinsic::dx_lerp,
18022+
ArrayRef<Value *>{X, Y, S}, nullptr, "dx.lerp");
1805018023
}
1805118024
case Builtin::BI__builtin_hlsl_elementwise_frac: {
1805218025
Value *Op0 = EmitScalarExpr(E->getArg(0));

clang/lib/Sema/SemaChecking.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5300,8 +5300,6 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
53005300
return true;
53015301
if (SemaBuiltinElementwiseTernaryMath(TheCall))
53025302
return true;
5303-
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
5304-
return true;
53055303
break;
53065304
}
53075305
case Builtin::BI__builtin_hlsl_mad: {

clang/test/CodeGenHLSL/builtins/lerp.hlsl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
77
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
88

9-
// NATIVE_HALF: %3 = fsub half %1, %0
10-
// NATIVE_HALF: %4 = fmul half %2, %3
11-
// NATIVE_HALF: %dx.lerp = fadd half %0, %4
9+
10+
// NATIVE_HALF: %dx.lerp = call half @llvm.dx.lerp.f16(half %0, half %1, half %2)
1211
// NATIVE_HALF: ret half %dx.lerp
13-
// NO_HALF: %3 = fsub float %1, %0
14-
// NO_HALF: %4 = fmul float %2, %3
15-
// NO_HALF: %dx.lerp = fadd float %0, %4
12+
// NO_HALF: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
1613
// NO_HALF: ret float %dx.lerp
1714
half test_lerp_half(half p0) { return lerp(p0, p0, p0); }
1815

@@ -34,9 +31,7 @@ half3 test_lerp_half3(half3 p0, half3 p1) { return lerp(p0, p0, p0); }
3431
// NO_HALF: ret <4 x float> %dx.lerp
3532
half4 test_lerp_half4(half4 p0, half4 p1) { return lerp(p0, p0, p0); }
3633

37-
// CHECK: %3 = fsub float %1, %0
38-
// CHECK: %4 = fmul float %2, %3
39-
// CHECK: %dx.lerp = fadd float %0, %4
34+
// CHECK: %dx.lerp = call float @llvm.dx.lerp.f32(float %0, float %1, float %2)
4035
// CHECK: ret float %dx.lerp
4136
float test_lerp_float(float p0, float p1) { return lerp(p0, p0, p0); }
4237

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ def int_dx_dot :
2929

3030
def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
3131

32-
def int_dx_lerp :
33-
Intrinsic<[LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
34-
[llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>,LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
35-
[IntrNoMem, IntrWillReturn] >;
32+
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], [IntrNoMem, IntrWillReturn] >;
3633

3734
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
3835
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_llvm_target(DirectXCodeGen
1919
DirectXSubtarget.cpp
2020
DirectXTargetMachine.cpp
2121
DXContainerGlobals.cpp
22+
DXILIntrinsicExpansion.cpp
2223
DXILMetadata.cpp
2324
DXILOpBuilder.cpp
2425
DXILOpLowering.cpp
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file This file contains DXIL intrinsic expansions for those that don't have
10+
/// opcodes in DirectX Intermediate Language (DXIL).
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "DXILIntrinsicExpansion.h"
14+
#include "DirectX.h"
15+
#include "llvm/ADT/STLExtras.h"
16+
#include "llvm/ADT/SmallVector.h"
17+
#include "llvm/CodeGen/Passes.h"
18+
#include "llvm/IR/IRBuilder.h"
19+
#include "llvm/IR/Instruction.h"
20+
#include "llvm/IR/Instructions.h"
21+
#include "llvm/IR/Intrinsics.h"
22+
#include "llvm/IR/IntrinsicsDirectX.h"
23+
#include "llvm/IR/Module.h"
24+
#include "llvm/IR/PassManager.h"
25+
#include "llvm/IR/Type.h"
26+
#include "llvm/Pass.h"
27+
#include "llvm/Support/ErrorHandling.h"
28+
29+
#define DEBUG_TYPE "dxil-intrinsic-expansion"
30+
#define M_LOG2E_F 1.44269504088896340735992468100189214f
31+
32+
using namespace llvm;
33+
34+
/// Tell whether \p F is one of the intrinsics this pass rewrites.
///
/// \param F function whose intrinsic ID is inspected.
/// \returns true for llvm.exp, dx.any, dx.lerp and dx.rcp; false for
/// everything else.
static bool isIntrinsicExpansion(Function &F) {
  const Intrinsic::ID ID = F.getIntrinsicID();
  return ID == Intrinsic::exp || ID == Intrinsic::dx_any ||
         ID == Intrinsic::dx_lerp || ID == Intrinsic::dx_rcp;
}
44+
45+
/// Rewrite a call to the exp intrinsic as exp2(log2(e) * x).
///
/// The replacement instructions are emitted immediately before \p Orig. Every
/// use of \p Orig is redirected to the new exp2 call, which inherits the
/// original call's tail-call flag and attributes, and \p Orig is then erased.
///
/// \param Orig the call to replace; it must not be used after this returns.
/// \returns true, because the call is always rewritten.
static bool expandExpIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  // Building from the instruction inserts new IR right before the call.
  IRBuilder<> Builder(Orig);
  Type *Ty = X->getType();
  // ConstantFP::get() produces a splat when given a vector type, so there is
  // no need to cast to FixedVectorType (the former dyn_cast result was
  // dereferenced without a null check).
  Constant *Log2eConst = ConstantFP::get(Ty, M_LOG2E_F);
  Value *NewX = Builder.CreateFMul(Log2eConst, X);
  auto *Exp2Call =
      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
  Exp2Call->setTailCall(Orig->isTailCall());
  Exp2Call->setAttributes(Orig->getAttributes());
  Orig->replaceAllUsesWith(Exp2Call);
  Orig->eraseFromParent();
  return true;
}
67+
68+
/// Rewrite a call to dx.any as an explicit "is any element non-zero" test.
///
/// The operand is compared against zero (unordered-or-not-equal for
/// floating-point operands, integer not-equal otherwise). For a vector
/// operand the per-element results are extracted and OR-ed together into a
/// single value. All uses of \p Orig are redirected to that value and \p Orig
/// is erased.
///
/// \param Orig the call to replace; it must not be used after this returns.
/// \returns true, because the call is always rewritten.
static bool expandAnyIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  // Building from the instruction inserts new IR right before the call.
  IRBuilder<> Builder(Orig);
  Type *Ty = X->getType();

  // Both constant getters return a splat when handed a vector type, so a
  // single comparison covers the scalar and the vector case.
  Value *Result = Ty->isFPOrFPVectorTy()
                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(Ty, 0.0))
                      : Builder.CreateICmpNE(X, ConstantInt::get(Ty, 0));

  if (Ty->isVectorTy()) {
    // cast<> asserts on anything that is not a fixed-width vector instead of
    // dereferencing a null dyn_cast result.
    const unsigned NumElts = cast<FixedVectorType>(Ty)->getNumElements();
    Value *Cond = Result;
    // Fold the element-wise comparison down to one value.
    Result = Builder.CreateExtractElement(Cond, static_cast<uint64_t>(0));
    for (unsigned I = 1; I < NumElts; ++I) {
      Value *Elt = Builder.CreateExtractElement(Cond, I);
      Result = Builder.CreateOr(Result, Elt);
    }
  }

  Orig->replaceAllUsesWith(Result);
  Orig->eraseFromParent();
  return true;
}
102+
103+
/// Rewrite a call to dx.lerp(x, y, s) as x + s * (y - x).
///
/// The arithmetic is emitted immediately before \p Orig; the final add is
/// named "dx.lerp". All uses of \p Orig are redirected to it and \p Orig is
/// erased.
///
/// \param Orig the call to replace; it must not be used after this returns.
/// \returns true, because the call is always rewritten.
static bool expandLerpIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Start = Orig->getOperand(0);
  Value *End = Orig->getOperand(1);
  Value *Weight = Orig->getOperand(2);

  Value *Delta = Builder.CreateFSub(End, Start);
  Value *Scaled = Builder.CreateFMul(Weight, Delta);
  Value *Lerp = Builder.CreateFAdd(Start, Scaled, "dx.lerp");

  Orig->replaceAllUsesWith(Lerp);
  Orig->eraseFromParent();
  return true;
}
116+
117+
/// Rewrite a call to dx.rcp(x) as the floating-point division 1.0 / x.
///
/// The division is emitted immediately before \p Orig and named "dx.rcp".
/// All uses of \p Orig are redirected to it and \p Orig is erased.
///
/// \param Orig the call to replace; it must not be used after this returns.
/// \returns true, because the call is always rewritten.
static bool expandReciprocalIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  // Building from the instruction inserts new IR right before the call.
  IRBuilder<> Builder(Orig);
  // ConstantFP::get() produces a splat when given a vector type, so there is
  // no need to cast to FixedVectorType (the former dyn_cast result was
  // dereferenced without a null check).
  Constant *One = ConstantFP::get(X->getType(), 1.0);
  auto *Result = Builder.CreateFDiv(One, X, "dx.rcp");
  Orig->replaceAllUsesWith(Result);
  Orig->eraseFromParent();
  return true;
}
135+
136+
/// Dispatch \p Orig to the expansion routine matching \p F's intrinsic ID.
///
/// \param F the intrinsic declaration being called.
/// \param Orig the call instruction to rewrite.
/// \returns the result of the selected expansion routine, or false when
/// \p F is not an intrinsic handled here.
static bool expandIntrinsic(Function &F, CallInst *Orig) {
  const Intrinsic::ID ID = F.getIntrinsicID();
  if (ID == Intrinsic::exp)
    return expandExpIntrinsic(Orig);
  if (ID == Intrinsic::dx_any)
    return expandAnyIntrinsic(Orig);
  if (ID == Intrinsic::dx_lerp)
    return expandLerpIntrinsic(Orig);
  if (ID == Intrinsic::dx_rcp)
    return expandReciprocalIntrinsic(Orig);
  return false;
}
149+
150+
/// Expand every call to an intrinsic selected by isIntrinsicExpansion().
///
/// For each such declaration in \p M, every CallInst user is handed to
/// expandIntrinsic(). A declaration is erased once at least one of its calls
/// was expanded and it has no users left.
///
/// \param M the module to transform.
/// \returns true if any call was expanded, false if the module is untouched.
static bool intrinsicExpansion(Module &M) {
  bool Changed = false;
  // Early-inc ranges are required because both calls and declarations are
  // erased while iterating.
  for (auto &F : make_early_inc_range(M.functions())) {
    if (!isIntrinsicExpansion(F))
      continue;
    bool IntrinsicExpanded = false;
    for (User *U : make_early_inc_range(F.users())) {
      auto *IntrinsicCall = dyn_cast<CallInst>(U);
      if (!IntrinsicCall)
        continue;
      // Accumulate instead of overwrite so one failed expansion cannot hide
      // earlier successful ones.
      IntrinsicExpanded |= expandIntrinsic(F, IntrinsicCall);
    }
    // Record the result before F is potentially destroyed below.
    Changed |= IntrinsicExpanded;
    if (F.user_empty() && IntrinsicExpanded)
      F.eraseFromParent();
  }
  // Report the real outcome so callers do not treat an unmodified module as
  // changed.
  return Changed;
}
166+
167+
/// New pass manager entry point: run the intrinsic expansion over \p M.
///
/// \returns PreservedAnalyses::none() when intrinsicExpansion() returns true,
/// PreservedAnalyses::all() otherwise. The analysis manager is unused.
PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
                                              ModuleAnalysisManager &) {
  const bool Changed = intrinsicExpansion(M);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
173+
174+
/// Legacy pass manager entry point: run the intrinsic expansion over \p M.
///
/// \returns whatever intrinsicExpansion() returns for \p M.
bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
  const bool Changed = intrinsicExpansion(M);
  return Changed;
}
177+
178+
// Storage for the legacy pass's static ID member, which the constructor
// passes to ModulePass.
char DXILIntrinsicExpansionLegacy::ID = 0;
179+
180+
// Legacy pass registration under DEBUG_TYPE ("dxil-intrinsic-expansion") with
// the description "DXIL Intrinsic Expansion". No pass dependencies are listed
// between BEGIN and END.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
181+
"DXIL Intrinsic Expansion", false, false)
182+
INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
183+
"DXIL Intrinsic Expansion", false, false)
184+
185+
/// Factory for the legacy-pass-manager flavour of the intrinsic expansion.
///
/// \returns a newly allocated DXILIntrinsicExpansionLegacy instance.
ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
  ModulePass *Pass = new DXILIntrinsicExpansionLegacy();
  return Pass;
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===- DXILIntrinsicExpansion.h - Prepare LLVM Module for DXIL encoding----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
9+
#define LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H
10+
11+
#include "DXILResource.h"
12+
#include "llvm/IR/PassManager.h"
13+
#include "llvm/Pass.h"
14+
15+
namespace llvm {
16+
17+
/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes
18+
class DXILIntrinsicExpansion : public PassInfoMixin<DXILIntrinsicExpansion> {
19+
public:
20+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
21+
};
22+
23+
class DXILIntrinsicExpansionLegacy : public ModulePass {
24+
25+
public:
26+
bool runOnModule(Module &M) override;
27+
DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
28+
29+
static char ID; // Pass identification.
30+
};
31+
} // namespace llvm
32+
33+
#endif // LLVM_TARGET_DIRECTX_DXILINTRINSICEXPANSION_H

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "DXILConstants.h"
14+
#include "DXILIntrinsicExpansion.h"
1415
#include "DXILOpBuilder.h"
1516
#include "DirectX.h"
1617
#include "llvm/ADT/SmallVector.h"
@@ -94,9 +95,12 @@ class DXILOpLoweringLegacy : public ModulePass {
9495
DXILOpLoweringLegacy() : ModulePass(ID) {}
9596

9697
static char ID; // Pass identification.
98+
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
99+
// Require DXILIntrinsicExpansionLegacy so that intrinsic expansion happens
// before DXIL op lowering.
100+
AU.addRequired<DXILIntrinsicExpansionLegacy>();
101+
}
97102
};
98103
char DXILOpLoweringLegacy::ID = 0;
99-
100104
} // end anonymous namespace
101105

102106
INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ void initializeDXILPrepareModulePass(PassRegistry &);
2828
/// Pass to convert modules into DXIL-compatible modules
2929
ModulePass *createDXILPrepareModulePass();
3030

31+
/// Initializer for DXIL Intrinsic Expansion
32+
void initializeDXILIntrinsicExpansionLegacyPass(PassRegistry &);
33+
34+
/// Pass to expand intrinsic operations that lack DXIL opCodes
35+
ModulePass *createDXILIntrinsicExpansionLegacyPass();
36+
3137
/// Initializer for DXILOpLowering
3238
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
3339

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ using namespace llvm;
3939
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
4040
RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
4141
auto *PR = PassRegistry::getPassRegistry();
42+
initializeDXILIntrinsicExpansionLegacyPass(*PR);
4243
initializeDXILPrepareModulePass(*PR);
4344
initializeEmbedDXILPassPass(*PR);
4445
initializeWriteDXILPassPass(*PR);
@@ -76,6 +77,7 @@ class DirectXPassConfig : public TargetPassConfig {
7677

7778
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
7879
void addCodeGenPrepare() override {
80+
addPass(createDXILIntrinsicExpansionLegacyPass());
7981
addPass(createDXILOpLoweringLegacyPass());
8082
addPass(createDXILPrepareModulePass());
8183
addPass(createDXILTranslateMetadataPass());

0 commit comments

Comments
 (0)