Skip to content

Commit a2fbc9a

Browse files
authored
[DirectX] Start the creation of a DXIL Instruction legalizer (#131221)
- Legalize i8 truncation back to original types - remove sext and truncs - Legalize i64 indices for insert/extract elements to i32 indices - fixes #126323 - fixes #129757
1 parent 092e255 commit a2fbc9a

12 files changed

+402
-26
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_llvm_target(DirectXCodeGen
3232
DXILShaderFlags.cpp
3333
DXILTranslateMetadata.cpp
3434
DXILRootSignature.cpp
35+
DXILLegalizePass.cpp
3536

3637
LINK_COMPONENTS
3738
Analysis
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//===---------------------------------------------------------------------===//
9+
///
10+
/// \file This file contains a pass to remove i8 truncations and i64 extract
11+
/// and insert elements.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
#include "DXILLegalizePass.h"
15+
#include "DirectX.h"
16+
#include "llvm/IR/Function.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/InstIterator.h"
19+
#include "llvm/IR/Instruction.h"
20+
#include "llvm/Pass.h"
21+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
22+
#include <functional>
23+
#include <map>
24+
#include <stack>
25+
#include <vector>
26+
27+
#define DEBUG_TYPE "dxil-legalize"
28+
29+
using namespace llvm;
30+
namespace {
31+
32+
/// Rewrites one instruction of an i8 use chain so that it operates on the
/// wider value that existed before the truncation to i8.
///
/// Three kinds of instruction are handled:
///  - A `trunc` producing i8 is recorded in \p ReplacedValues as an alias of
///    its (wider) source operand and queued for removal.
///  - An instruction that produces i8, or a compare whose first operand is
///    i8, is re-created from widened operands. Only binary operators and
///    compares are re-created; other instructions are left alone.
///  - Any other cast whose source is i8 has its uses redirected to the widened
///    source value and is queued for removal.
///
/// \param I              Instruction to inspect.
/// \param ToRemove       Receives instructions that became dead; the caller is
///                       responsible for erasing them.
/// \param ReplacedValues Maps an original i8 value to its widened replacement.
static void fixI8TruncUseChain(Instruction &I,
                               std::stack<Instruction *> &ToRemove,
                               std::map<Value *, Value *> &ReplacedValues) {

  auto *Cmp = dyn_cast<CmpInst>(&I);

  if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
    if (Trunc->getDestTy()->isIntegerTy(8)) {
      ReplacedValues[Trunc] = Trunc->getOperand(0);
      ToRemove.push(Trunc);
    }
  } else if (I.getType()->isIntegerTy(8) ||
             (Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
    IRBuilder<> Builder(&I);

    // Default to i32; if any operand has already been widened, adopt the type
    // of its replacement (the last such operand wins).
    Type *InstrType = IntegerType::get(I.getContext(), 32);
    for (Value *Op : I.operands()) {
      auto It = ReplacedValues.find(Op);
      if (It != ReplacedValues.end())
        InstrType = It->second->getType();
    }

    std::vector<Value *> NewOperands;
    NewOperands.reserve(I.getNumOperands());
    for (Value *Op : I.operands()) {
      auto It = ReplacedValues.find(Op);
      if (It != ReplacedValues.end()) {
        NewOperands.push_back(It->second);
      } else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
        APInt Value = Imm->getValue();
        unsigned NewBitWidth = InstrType->getIntegerBitWidth();
        // Note: options here are sext or sextOrTrunc.
        // Since i8 isn't supported, we assume new values
        // will always have a higher bitness.
        APInt NewValue = Value.sext(NewBitWidth);
        NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
      } else {
        assert(!Op->getType()->isIntegerTy(8));
        NewOperands.push_back(Op);
      }
    }

    Value *NewInst = nullptr;
    if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
      NewInst =
          Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);

      // Carry the wrap flags over to the widened operation.
      if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
        if (OBO->hasNoSignedWrap())
          cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
        if (OBO->hasNoUnsignedWrap())
          cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
      }
    } else if (Cmp) {
      NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
                                  NewOperands[1]);
      // The compare result type does not change, so users can be redirected
      // immediately.
      Cmp->replaceAllUsesWith(NewInst);
    }

    if (NewInst) {
      ReplacedValues[&I] = NewInst;
      ToRemove.push(&I);
    }
  } else if (auto *Cast = dyn_cast<CastInst>(&I)) {
    if (!Cast->getSrcTy()->isIntegerTy(8))
      return;

    // Look the source up without inserting: an i8 value that was never
    // rewritten has no replacement, and operator[] would hand a null pointer
    // to replaceAllUsesWith.
    auto It = ReplacedValues.find(Cast->getOperand(0));
    if (It == ReplacedValues.end())
      return;

    Value *Replacement = It->second;
    if (Replacement->getType() != Cast->getType()) {
      // The widened value need not have the cast's result type (for example
      // an i32 replacement feeding `sext i8 ... to i64`); adjust the width so
      // that replaceAllUsesWith sees matching types.
      IRBuilder<> Builder(Cast);
      switch (Cast->getOpcode()) {
      case Instruction::ZExt:
        Replacement = Builder.CreateZExtOrTrunc(Replacement, Cast->getType());
        break;
      case Instruction::SExt:
        Replacement = Builder.CreateSExtOrTrunc(Replacement, Cast->getType());
        break;
      default:
        // Not an integer extension; no type-correct replacement is known.
        return;
      }
    }

    Cast->replaceAllUsesWith(Replacement);
    ToRemove.push(Cast);
  }
}
100+
101+
static void
102+
downcastI64toI32InsertExtractElements(Instruction &I,
103+
std::stack<Instruction *> &ToRemove,
104+
std::map<Value *, Value *> &) {
105+
106+
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
107+
Value *Idx = Extract->getIndexOperand();
108+
auto *CI = dyn_cast<ConstantInt>(Idx);
109+
if (CI && CI->getBitWidth() == 64) {
110+
IRBuilder<> Builder(Extract);
111+
int64_t IndexValue = CI->getSExtValue();
112+
auto *Idx32 =
113+
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
114+
Value *NewExtract = Builder.CreateExtractElement(
115+
Extract->getVectorOperand(), Idx32, Extract->getName());
116+
117+
Extract->replaceAllUsesWith(NewExtract);
118+
ToRemove.push(Extract);
119+
}
120+
}
121+
122+
if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
123+
Value *Idx = Insert->getOperand(2);
124+
auto *CI = dyn_cast<ConstantInt>(Idx);
125+
if (CI && CI->getBitWidth() == 64) {
126+
int64_t IndexValue = CI->getSExtValue();
127+
auto *Idx32 =
128+
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
129+
IRBuilder<> Builder(Insert);
130+
Value *Insert32Index = Builder.CreateInsertElement(
131+
Insert->getOperand(0), Insert->getOperand(1), Idx32,
132+
Insert->getName());
133+
134+
Insert->replaceAllUsesWith(Insert32Index);
135+
ToRemove.push(Insert);
136+
}
137+
}
138+
}
139+
140+
class DXILLegalizationPipeline {
141+
142+
public:
143+
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
144+
145+
bool runLegalizationPipeline(Function &F) {
146+
std::stack<Instruction *> ToRemove;
147+
std::map<Value *, Value *> ReplacedValues;
148+
for (auto &I : instructions(F)) {
149+
for (auto &LegalizationFn : LegalizationPipeline) {
150+
LegalizationFn(I, ToRemove, ReplacedValues);
151+
}
152+
}
153+
bool MadeChanges = !ToRemove.empty();
154+
155+
while (!ToRemove.empty()) {
156+
Instruction *I = ToRemove.top();
157+
I->eraseFromParent();
158+
ToRemove.pop();
159+
}
160+
161+
return MadeChanges;
162+
}
163+
164+
private:
165+
std::vector<std::function<void(Instruction &, std::stack<Instruction *> &,
166+
std::map<Value *, Value *> &)>>
167+
LegalizationPipeline;
168+
169+
void initializeLegalizationPipeline() {
170+
LegalizationPipeline.push_back(fixI8TruncUseChain);
171+
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
172+
}
173+
};
174+
175+
class DXILLegalizeLegacy : public FunctionPass {
176+
177+
public:
178+
bool runOnFunction(Function &F) override;
179+
DXILLegalizeLegacy() : FunctionPass(ID) {}
180+
181+
static char ID; // Pass identification.
182+
};
183+
} // namespace
184+
185+
/// New pass-manager entry point: runs the legalization pipeline on \p F.
///
/// \returns PreservedAnalyses::all() when nothing was changed; otherwise a set
/// that preserves no analyses.
PreservedAnalyses DXILLegalizePass::run(Function &F,
                                        FunctionAnalysisManager &FAM) {
  DXILLegalizationPipeline Pipeline;
  if (Pipeline.runLegalizationPipeline(F))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
194+
195+
/// Legacy pass-manager entry point: builds a fresh pipeline, runs it on \p F,
/// and reports whether the pipeline made changes.
bool DXILLegalizeLegacy::runOnFunction(Function &F) {
  return DXILLegalizationPipeline().runLegalizationPipeline(F);
}
199+
200+
// Out-of-line definition of the static pass-identification member declared in
// DXILLegalizeLegacy.
char DXILLegalizeLegacy::ID = 0;
201+
202+
// Pass-registration boilerplate for the legacy pass, using DEBUG_TYPE
// ("dxil-legalize") as the argument string and "DXIL Legalizer" as the
// description. Nothing is declared between BEGIN and END.
// NOTE(review): the two `false` arguments are presumably the CFG-only and
// is-analysis flags -- confirm against the INITIALIZE_PASS macro definition.
INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
203+
false)
204+
INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
205+
false)
206+
207+
/// Factory for the legacy-pass-manager legalizer. Returns a newly allocated
/// DXILLegalizeLegacy instance.
/// NOTE(review): ownership presumably passes to the pass manager that the
/// result is added to -- confirm at the call site.
FunctionPass *llvm::createDXILLegalizeLegacyPass() {
208+
return new DXILLegalizeLegacy();
209+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H
10+
#define LLVM_TARGET_DIRECTX_LEGALIZE_H
11+
12+
#include "llvm/IR/PassManager.h"
13+
14+
namespace llvm {
15+
16+
/// New pass-manager function pass that legalizes LLVM IR for DXIL.
class DXILLegalizePass : public PassInfoMixin<DXILLegalizePass> {
17+
public:
18+
/// Runs the legalizer on \p F and returns the analyses that remain valid.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
19+
};
20+
} // namespace llvm
21+
22+
#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
4747
/// Pass to flatten arrays into a one dimensional DXIL legal form
4848
ModulePass *createDXILFlattenArraysLegacyPass();
4949

50+
/// Initializer for the DXIL legalization pass
51+
void initializeDXILLegalizeLegacyPass(PassRegistry &);
52+
53+
/// Pass to legalize DXIL by removing i8 truncations and i64 insert/extract
54+
/// elements
55+
FunctionPass *createDXILLegalizeLegacyPass();
56+
5057
/// Initializer for DXILOpLowering
5158
void initializeDXILOpLoweringLegacyPass(PassRegistry &);
5259

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
3838
#define FUNCTION_PASS(NAME, CREATE_PASS)
3939
#endif
4040
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
41+
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
4142
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "DXILDataScalarization.h"
1616
#include "DXILFlattenArrays.h"
1717
#include "DXILIntrinsicExpansion.h"
18+
#include "DXILLegalizePass.h"
1819
#include "DXILOpLowering.h"
1920
#include "DXILPrettyPrinter.h"
2021
#include "DXILResourceAccess.h"
@@ -52,6 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
5253
initializeDXILDataScalarizationLegacyPass(*PR);
5354
initializeDXILFlattenArraysLegacyPass(*PR);
5455
initializeScalarizerLegacyPassPass(*PR);
56+
initializeDXILLegalizeLegacyPass(*PR);
5557
initializeDXILPrepareModulePass(*PR);
5658
initializeEmbedDXILPassPass(*PR);
5759
initializeWriteDXILPassPass(*PR);
@@ -99,6 +101,7 @@ class DirectXPassConfig : public TargetPassConfig {
99101
ScalarizerPassOptions DxilScalarOptions;
100102
DxilScalarOptions.ScalarizeLoadStore = true;
101103
addPass(createScalarizerPass(DxilScalarOptions));
104+
addPass(createDXILLegalizeLegacyPass());
102105
addPass(createDXILTranslateMetadataLegacyPass());
103106
addPass(createDXILOpLoweringLegacyPass());
104107
addPass(createDXILPrepareModulePass());

llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
; CHECK-LABEL define void @main()
2020
define void @main() local_unnamed_addr #0 {
2121
entry:
22-
; DXOP: %In_h.i1 = call %dx.types.Handle @dx.op.createHandle
23-
; DXOP: %Out_h.i2 = call %dx.types.Handle @dx.op.createHandle
22+
; DXOP: [[In_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle
23+
; DXOP: [[Out_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle
2424
%In_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false)
2525
store target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %In_h.i, ptr @In, align 4
2626
%Out_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 4, i32 1, i32 1, i32 0, i1 false)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
2+
3+
define noundef <4 x float> @float4_extract(<4 x float> noundef %a) {
4+
entry:
5+
; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0
6+
; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1
7+
; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2
8+
; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3
9+
; CHECK: insertelement <4 x float> poison, float [[ee0]], i32 0
10+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee1]], i32 1
11+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee2]], i32 2
12+
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee3]], i32 3
13+
14+
%a.i0 = extractelement <4 x float> %a, i64 0
15+
%a.i1 = extractelement <4 x float> %a, i64 1
16+
%a.i2 = extractelement <4 x float> %a, i64 2
17+
%a.i3 = extractelement <4 x float> %a, i64 3
18+
19+
%.upto0 = insertelement <4 x float> poison, float %a.i0, i64 0
20+
%.upto1 = insertelement <4 x float> %.upto0, float %a.i1, i64 1
21+
%.upto2 = insertelement <4 x float> %.upto1, float %a.i2, i64 2
22+
%0 = insertelement <4 x float> %.upto2, float %a.i3, i64 3
23+
ret <4 x float> %0
24+
}

0 commit comments

Comments
 (0)