Skip to content

Commit 85285be

Browse files
committed
[DirectX backend] Add pass to lower llvm intrinsic into dxil op function.
A new pass, DXILOpLowering, was added. It scans all LLVM intrinsics and, for each one that maps to a DXIL op, creates the corresponding DXIL op function, then translates call instructions on the intrinsic into calls to that DXIL op function. A DXIL op function takes an extra i32 argument at the beginning of its argument list for the DXIL opcode, so setCalledFunction cannot be used to update the call instruction on the intrinsic. This commit only supports sin to start the work. Reviewed By: kuhar, beanz Differential Revision: https://reviews.llvm.org/D124805
1 parent 4537aae commit 85285be

File tree

7 files changed

+362
-1
lines changed

7 files changed

+362
-1
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_public_tablegen_target(DirectXCommonTableGen)
99
add_llvm_target(DirectXCodeGen
1010
DirectXSubtarget.cpp
1111
DirectXTargetMachine.cpp
12+
DXILOpLowering.cpp
1213
DXILPointerType.cpp
1314
DXILPrepare.cpp
1415
PointerTypeAnalysis.cpp
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- DXILConstants.h - Essential DXIL constants -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file This file contains essential DXIL constants.
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
13+
#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
14+
15+
namespace llvm {
namespace DXIL {

/// Operations specified by DXIL. Each enumerator's value is the opcode passed
/// as the leading i32 argument of the matching DXIL op function call.
enum class OpCode : unsigned {
  /// Returns sine(theta) for theta in radians.
  Sin = 13,
};

/// Groups of DXIL operations with equivalent function templates.
enum class OpCodeClass : unsigned {
  Unary = 0,
};

} // namespace DXIL
} // namespace llvm
28+
29+
#endif
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file This file contains passes and utilities to lower llvm intrinsic call
10+
/// to DXILOp function call.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "DXILConstants.h"
14+
#include "DirectX.h"
15+
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/CodeGen/Passes.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/Instruction.h"
19+
#include "llvm/IR/Intrinsics.h"
20+
#include "llvm/IR/Module.h"
21+
#include "llvm/IR/PassManager.h"
22+
#include "llvm/Pass.h"
23+
#include "llvm/Support/ErrorHandling.h"
24+
25+
#define DEBUG_TYPE "dxil-op-lower"
26+
27+
using namespace llvm;
28+
using namespace llvm::DXIL;
29+
30+
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
31+
32+
// Overload type categories for DXIL operations. Every enumerator occupies its
// own bit so that several kinds can be OR-ed together into a uint16_t mask.
enum OverloadKind : uint16_t {
  VOID = 1 << 0,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};
45+
46+
// Returns the suffix used in a DXIL op function name for a scalar overload
// kind (e.g. "f32" for FLOAT, "i1" for I1).
//
// VOID, ObjectType and UserDefineType have no fixed suffix and must not be
// passed here; neither may a value that is not exactly one enumerator (for
// instance an OR-ed mask). All of those reach the llvm_unreachable below.
static const char *getOverloadTypeName(OverloadKind Kind) {
  switch (Kind) {
  case OverloadKind::HALF:
    return "f16";
  case OverloadKind::FLOAT:
    return "f32";
  case OverloadKind::DOUBLE:
    return "f64";
  case OverloadKind::I1:
    return "i1";
  case OverloadKind::I8:
    return "i8";
  case OverloadKind::I16:
    return "i16";
  case OverloadKind::I32:
    return "i32";
  case OverloadKind::I64:
    return "i64";
  case OverloadKind::VOID:
  case OverloadKind::ObjectType:
  case OverloadKind::UserDefineType:
    break;
  }
  // Kept outside the switch so that every path that does not return a name
  // ends here, instead of falling off the end of a non-void function.
  llvm_unreachable("invalid overload type for name");
}
71+
72+
// Classifies an LLVM type into the OverloadKind used for DXIL op overloads.
//
// void, half, float and double map to their own kinds; integers of width 1,
// 8, 16, 32 and 64 map to I1..I64; pointers map to UserDefineType and structs
// to ObjectType. Any other type or integer width is treated as unreachable.
static OverloadKind getOverloadKind(Type *Ty) {
  if (Ty->isVoidTy())
    return OverloadKind::VOID;
  if (Ty->isHalfTy())
    return OverloadKind::HALF;
  if (Ty->isFloatTy())
    return OverloadKind::FLOAT;
  if (Ty->isDoubleTy())
    return OverloadKind::DOUBLE;

  if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
    switch (IntTy->getBitWidth()) {
    case 1:
      return OverloadKind::I1;
    case 8:
      return OverloadKind::I8;
    case 16:
      return OverloadKind::I16;
    case 32:
      return OverloadKind::I32;
    case 64:
      return OverloadKind::I64;
    default:
      llvm_unreachable("invalid overload type");
    }
  }

  if (Ty->isPointerTy())
    return OverloadKind::UserDefineType;
  if (Ty->isStructTy())
    return OverloadKind::ObjectType;

  llvm_unreachable("invalid overload type");
}
111+
112+
// Produces the type suffix for a DXIL op function name.
//
// Kinds below UserDefineType use the fixed names from getOverloadTypeName.
// UserDefineType and ObjectType use the struct name of Ty. Any other value
// falls back to the textual form printed by Type::print.
//
// NOTE(review): getOverloadKind reports UserDefineType for pointer types, yet
// this function casts Ty to StructType for that kind - confirm which type
// callers are expected to pass here.
static std::string getTypeName(OverloadKind Kind, Type *Ty) {
  if (Kind < OverloadKind::UserDefineType)
    return getOverloadTypeName(Kind);

  if (Kind == OverloadKind::UserDefineType ||
      Kind == OverloadKind::ObjectType)
    return cast<StructType>(Ty)->getStructName().str();

  std::string Printed;
  raw_string_ostream Stream(Printed);
  Ty->print(Stream);
  return Stream.str();
}
128+
129+
// Static properties.
130+
struct OpCodeProperty {
131+
DXIL::OpCode OpCode;
132+
// FIXME: change OpCodeName into index to a large string constant when move to
133+
// tableGen.
134+
const char *OpCodeName;
135+
DXIL::OpCodeClass OpCodeClass;
136+
uint16_t OverloadTys;
137+
llvm::Attribute::AttrKind FuncAttr;
138+
};
139+
140+
static const char *getOpCodeClassName(const OpCodeProperty &Prop) {
141+
// FIXME: generate this table with tableGen.
142+
static const char *OpCodeClassNames[] = {
143+
"unary",
144+
};
145+
unsigned Index = static_cast<unsigned>(Prop.OpCodeClass);
146+
assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) &&
147+
"Out of bound OpCodeClass");
148+
return OpCodeClassNames[Index];
149+
}
150+
151+
// Builds the DXIL op function name for an overload:
//   "dx.op.<class>"          when Kind is VOID,
//   "dx.op.<class>.<type>"   otherwise.
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                         const OpCodeProperty &Prop) {
  std::string Name = DXILOpNamePrefix.str();
  Name += getOpCodeClassName(Prop);
  if (Kind != OverloadKind::VOID) {
    Name += ".";
    Name += getTypeName(Kind, Ty);
  }
  return Name;
}
160+
161+
// Looks up the static properties of DXILOp.
//
// The table must stay sorted by opcode because it is searched with a binary
// search. Asking for an opcode that has no table entry is a programming
// error and is caught by the assert below.
static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) {
  // FIXME: generate this table with tableGen.
  static const OpCodeProperty OpCodeProps[] = {
      {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary,
       OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
  };
  // FIXME: change search to indexing with
  // DXILOp once all DXIL op is added.
  const OpCodeProperty *Prop = llvm::lower_bound(
      OpCodeProps, DXILOp, [](const OpCodeProperty &P, DXIL::OpCode Op) {
        return P.OpCode < Op;
      });
  // lower_bound yields the first entry that is not less than DXILOp. That is
  // the end pointer, or an entry for a different opcode, when DXILOp is
  // missing from the table, so check for an exact hit.
  assert(Prop != OpCodeProps + sizeof(OpCodeProps) / sizeof(OpCodeProps[0]) &&
         Prop->OpCode == DXILOp && "DXIL opcode has no OpCodeProperty entry");
  return Prop;
}
178+
179+
// Returns the DXIL op function that calls to the intrinsic F are lowered to,
// declaring it in M if needed.
//
// The function is named after the op class and overload type and has the
// signature of F with an i32 opcode parameter prepended. An overload type that
// the operation does not support is reported as a fatal error.
static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
                                           Module &M) {
  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);

  // Get return type as overload type for DXILOp.
  // Only simple mapping case here, so return type is good enough.
  Type *OverloadTy = F.getReturnType();

  OverloadKind Kind = getOverloadKind(OverloadTy);
  // FIXME: find the issue and report error in clang instead of check it in
  // backend.
  // An unsupported overload can come from otherwise valid input IR, so it
  // must be diagnosed in release builds too rather than marked unreachable.
  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
    report_fatal_error(Twine("invalid overload type for DXIL operation '") +
                           Prop->OpCodeName + "'",
                       /*gen_crash_diag=*/false);

  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);

  LLVMContext &Ctx = M.getContext();
  Type *OpCodeTy = Type::getInt32Ty(Ctx);

  SmallVector<Type *> ArgTypes;
  // DXIL has i32 opcode as first arg.
  ArgTypes.emplace_back(OpCodeTy);
  FunctionType *FT = F.getFunctionType();
  ArgTypes.append(FT->param_begin(), FT->param_end());
  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);

  // The name depends only on op class and overload type, so a function with
  // this name may already be present in M. getOrInsertFunction reuses it in
  // that case instead of creating a second declaration.
  // NOTE(review): Prop->FuncAttr is not applied to the declaration here -
  // confirm whether it should be.
  return M.getOrInsertFunction(FnName, DXILOpFT);
}
208+
209+
// Rewrites every call instruction that uses the intrinsic declaration F into
// a call to the DXIL op function for DXILOp. The new call receives the opcode
// as an i32 constant followed by the original call arguments. Users of F that
// are not call instructions are left untouched, and F is erased only when no
// users remain afterwards.
static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
  FunctionCallee OpFn = createDXILOpFunction(DXILOp, F, M);
  IRBuilder<> Builder(M.getContext());
  Value *OpCodeVal = Builder.getInt32(static_cast<unsigned>(DXILOp));

  // Calls are erased while walking the use list, hence the early-inc range.
  for (User *IntrinsicUser : make_early_inc_range(F.users())) {
    if (auto *OldCall = dyn_cast<CallInst>(IntrinsicUser)) {
      SmallVector<Value *> OpArgs;
      OpArgs.emplace_back(OpCodeVal);
      OpArgs.append(OldCall->arg_begin(), OldCall->arg_end());

      Builder.SetInsertPoint(OldCall);
      CallInst *NewCall = Builder.CreateCall(OpFn, OpArgs);
      OldCall->replaceAllUsesWith(NewCall);
      OldCall->eraseFromParent();
    }
  }

  if (F.user_empty())
    F.eraseFromParent();
}
229+
230+
// Lowers every function declaration in M whose intrinsic ID has a DXIL
// opcode mapping. Returns true if at least one such declaration was handed to
// lowerIntrinsic.
static bool lowerIntrinsics(Module &M) {
  static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
      {Intrinsic::sin, DXIL::OpCode::Sin}};

  bool Changed = false;
  // lowerIntrinsic may erase the function being visited, hence the early-inc
  // range.
  for (Function &Fn : make_early_inc_range(M.functions())) {
    if (Fn.isDeclaration()) {
      auto Entry = LowerMap.find(Fn.getIntrinsicID());
      if (Entry != LowerMap.end()) {
        lowerIntrinsic(Entry->second, Fn, M);
        Changed = true;
      }
    }
  }
  return Changed;
}
246+
247+
namespace {
248+
/// A pass that transforms external global definitions into declarations.
249+
class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
250+
public:
251+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
252+
if (lowerIntrinsics(M))
253+
return PreservedAnalyses::none();
254+
return PreservedAnalyses::all();
255+
}
256+
};
257+
} // namespace
258+
259+
namespace {
260+
class DXILOpLoweringLegacy : public ModulePass {
261+
public:
262+
bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
263+
StringRef getPassName() const override { return "DXIL Op Lowering"; }
264+
DXILOpLoweringLegacy() : ModulePass(ID) {}
265+
266+
static char ID; // Pass identification.
267+
};
268+
char DXILOpLoweringLegacy::ID = 0;
269+
270+
} // end anonymous namespace
271+
272+
// Registers the legacy pass under DEBUG_TYPE. It declares no pass
// dependencies, so the single-macro form is sufficient.
INITIALIZE_PASS(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
                false)
276+
277+
// Creates a new legacy DXIL op lowering pass instance and returns it as a
// raw pointer.
ModulePass *llvm::createDXILOpLoweringLegacyPass() {
  ModulePass *LoweringPass = new DXILOpLoweringLegacy();
  return LoweringPass;
}

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ void initializeDXILPrepareModulePass(PassRegistry &);
2323

2424
/// Pass to convert modules into DXIL-compatible modules
2525
ModulePass *createDXILPrepareModulePass();
26+
27+
/// Initializer for DXILOpLowering
28+
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
29+
30+
/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
31+
ModulePass *createDXILOpLoweringLegacyPass();
32+
2633
} // namespace llvm
2734

2835
#endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
3434
RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
3535
auto *PR = PassRegistry::getPassRegistry();
3636
initializeDXILPrepareModulePass(*PR);
37+
initializeDXILOpLoweringLegacyPass(*PR);
3738
}
3839

3940
class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -84,6 +85,7 @@ bool DirectXTargetMachine::addPassesToEmitFile(
8485
PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
8586
CodeGenFileType FileType, bool DisableVerify,
8687
MachineModuleInfoWrapperPass *MMIWP) {
88+
PM.add(createDXILOpLoweringLegacyPass());
8789
PM.add(createDXILPrepareModulePass());
8890
switch (FileType) {
8991
case CGFT_AssemblyFile:

llvm/test/CodeGen/DirectX/sin.ll

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
2+
3+
; Make sure dxil operation function calls for sin are generated for float and half.
4+
; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
5+
; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
6+
7+
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
8+
target triple = "dxil-pc-shadermodel6.7-library"
9+
10+
; Function Attrs: noinline nounwind optnone
11+
define noundef float @_Z3foof(float noundef %a) #0 {
12+
entry:
13+
%a.addr = alloca float, align 4
14+
store float %a, ptr %a.addr, align 4
15+
%0 = load float, ptr %a.addr, align 4
16+
%1 = call float @llvm.sin.f32(float %0)
17+
ret float %1
18+
}
19+
20+
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
21+
declare float @llvm.sin.f32(float) #1
22+
23+
; Function Attrs: noinline nounwind optnone
24+
define noundef half @_Z3barDh(half noundef %a) #0 {
25+
entry:
26+
%a.addr = alloca half, align 2
27+
store half %a, ptr %a.addr, align 2
28+
%0 = load half, ptr %a.addr, align 2
29+
%1 = call half @llvm.sin.f16(half %0)
30+
ret half %1
31+
}
32+
33+
; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
34+
declare half @llvm.sin.f16(half) #1
35+
36+
attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
37+
attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
38+
39+
!llvm.module.flags = !{!0}
40+
!llvm.ident = !{!1}
41+
42+
!0 = !{i32 1, !"wchar_size", i32 4}
43+
!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}

0 commit comments

Comments
 (0)