Skip to content

Commit 4ae12ca

Browse files
committed
[LegacyPM][DirectX] Add the scalarizer pass for DXIL legalization
1 parent 8287831 commit 4ae12ca

File tree

8 files changed

+62
-9
lines changed

8 files changed

+62
-9
lines changed

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ void initializeSafepointIRVerifierPass(PassRegistry &);
276276
void initializeSelectOptimizePass(PassRegistry &);
277277
void initializeScalarEvolutionWrapperPassPass(PassRegistry &);
278278
void initializeScalarizeMaskedMemIntrinLegacyPassPass(PassRegistry &);
279+
void initializeScalarizerLegacyPassPass(PassRegistry&);
279280
void initializeScavengerTestPass(PassRegistry &);
280281
void initializeScopedNoAliasAAWrapperPassPass(PassRegistry &);
281282
void initializeSeparateConstOffsetFromGEPLegacyPassPass(PassRegistry &);

llvm/include/llvm/LinkAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ struct ForcePassLinking {
130130
(void)llvm::createLowerAtomicPass();
131131
(void)llvm::createLoadStoreVectorizerPass();
132132
(void)llvm::createPartiallyInlineLibCallsPass();
133+
(void)llvm::createScalarizerPass();
133134
(void)llvm::createSeparateConstOffsetFromGEPPass();
134135
(void)llvm::createSpeculativeExecutionPass();
135136
(void)llvm::createSpeculativeExecutionIfHasBranchDivergencePass();

llvm/include/llvm/Transforms/Scalar/Scalarizer.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define LLVM_TRANSFORMS_SCALAR_SCALARIZER_H
2020

2121
#include "llvm/IR/PassManager.h"
22+
#include "llvm/Pass.h"
2223
#include <optional>
2324

2425
namespace llvm {
@@ -50,6 +51,19 @@ class ScalarizerPass : public PassInfoMixin<ScalarizerPass> {
5051
void setScalarizeLoadStore(bool Value) { Options.ScalarizeLoadStore = Value; }
5152
void setScalarizeMinBits(unsigned Value) { Options.ScalarizeMinBits = Value; }
5253
};
54+
55+
/// Create a legacy pass manager instance of the Scalarizer pass.
56+
FunctionPass *createScalarizerPass();
57+
58+
/// Legacy pass manager wrapper around the scalarizer. It is a FunctionPass
/// whose behaviour is configured through the public Options member.
class ScalarizerLegacyPass : public FunctionPass {
59+
public:
60+
// Pass identification; handed to the FunctionPass base by the constructor.
static char ID;
61+
// Settings forwarded to the ScalarizerVisitor built in runOnFunction().
ScalarizerPassOptions Options;
62+
ScalarizerLegacyPass();
63+
bool runOnFunction(Function &F) override;
64+
void getAnalysisUsage(AnalysisUsage& AU) const override;
65+
};
66+
5367
}
5468

5569
#endif /* LLVM_TRANSFORMS_SCALAR_SCALARIZER_H */

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/InitializePasses.h"
2525
#include "llvm/Pass.h"
2626
#include "llvm/Support/ErrorHandling.h"
27+
#include "llvm/Transforms/Scalar/Scalarizer.h"
2728

2829
#define DEBUG_TYPE "dxil-op-lower"
2930

@@ -521,6 +522,7 @@ class DXILOpLoweringLegacy : public ModulePass {
521522
static char ID; // Pass identification.
522523
// Declares the passes this lowering pass depends on and what it preserves.
void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
523524
AU.addRequired<DXILIntrinsicExpansionLegacy>();
525+
// NOTE(review): ScalarizerLegacyPass is a transform FunctionPass, while this
// class is a ModulePass, and the scalarizer is also added explicitly in
// DirectXPassConfig::addCodeGenPrepare(). Confirm this requirement is needed
// and behaves as intended under the legacy pass manager.
AU.addRequired<ScalarizerLegacyPass>();
524526
AU.addRequired<DXILResourceWrapperPass>();
525527
AU.addPreserved<DXILResourceWrapperPass>();
526528
}

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/CodeGen/MachineModuleInfo.h"
2727
#include "llvm/CodeGen/Passes.h"
2828
#include "llvm/CodeGen/TargetPassConfig.h"
29+
#include "llvm/InitializePasses.h"
2930
#include "llvm/IR/IRPrintingPasses.h"
3031
#include "llvm/IR/LegacyPassManager.h"
3132
#include "llvm/MC/MCSectionDXContainer.h"
@@ -36,6 +37,7 @@
3637
#include "llvm/Support/Compiler.h"
3738
#include "llvm/Support/ErrorHandling.h"
3839
#include "llvm/Target/TargetLoweringObjectFile.h"
40+
#include "llvm/Transforms/Scalar/Scalarizer.h"
3941
#include <optional>
4042

4143
using namespace llvm;
@@ -44,6 +46,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
4446
RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
4547
auto *PR = PassRegistry::getPassRegistry();
4648
initializeDXILIntrinsicExpansionLegacyPass(*PR);
49+
initializeScalarizerLegacyPassPass(*PR);
4750
initializeDXILPrepareModulePass(*PR);
4851
initializeEmbedDXILPassPass(*PR);
4952
initializeWriteDXILPassPass(*PR);
@@ -83,6 +86,7 @@ class DirectXPassConfig : public TargetPassConfig {
8386
FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
8487
void addCodeGenPrepare() override {
8588
addPass(createDXILIntrinsicExpansionLegacyPass());
89+
addPass(createScalarizerPass());
8690
addPass(createDXILOpLoweringLegacyPass());
8791
addPass(createDXILFinalizeLinkageLegacyPass());
8892
addPass(createDXILTranslateMetadataLegacyPass());

llvm/lib/Transforms/Scalar/Scalar.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using namespace llvm;
2121
void llvm::initializeScalarOpts(PassRegistry &Registry) {
2222
initializeConstantHoistingLegacyPassPass(Registry);
2323
initializeDCELegacyPassPass(Registry);
24+
initializeScalarizerLegacyPassPass(Registry);
2425
initializeGVNLegacyPassPass(Registry);
2526
initializeEarlyCSELegacyPassPass(Registry);
2627
initializeEarlyCSEMemSSALegacyPassPass(Registry);

llvm/lib/Transforms/Scalar/Scalarizer.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/IR/Module.h"
3737
#include "llvm/IR/Type.h"
3838
#include "llvm/IR/Value.h"
39+
#include "llvm/InitializePasses.h"
3940
#include "llvm/Support/Casting.h"
4041
#include "llvm/Support/CommandLine.h"
4142
#include "llvm/Transforms/Utils/Local.h"
@@ -339,9 +340,25 @@ class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
339340
const bool ScalarizeLoadStore;
340341
const unsigned ScalarizeMinBits;
341342
};
342-
343343
} // end anonymous namespace
344344

345+
// Constructs the legacy pass and unconditionally enables scalarization of
// variable insert/extract and of loads/stores.
// NOTE(review): createScalarizerPass() takes no arguments, so these values
// can only be changed by writing to the public Options member afterwards.
// Consider accepting a ScalarizerPassOptions parameter instead.
ScalarizerLegacyPass::ScalarizerLegacyPass() : FunctionPass(ID) {
346+
Options.ScalarizeVariableInsertExtract = true;
347+
Options.ScalarizeLoadStore = true;
348+
}
349+
350+
// Declares the dominator tree as both required (runOnFunction() fetches it)
// and preserved by this pass.
void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage& AU) const {
351+
AU.addRequired<DominatorTreeWrapperPass>();
352+
AU.addPreserved<DominatorTreeWrapperPass>();
353+
}
354+
355+
// Definition of the pass ID, followed by registration of the pass under the
// name "scalarizer" with DominatorTreeWrapperPass declared as a dependency.
char ScalarizerLegacyPass::ID = 0;
356+
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
357+
"Scalarize vector operations", false, false)
358+
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
359+
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
360+
"Scalarize vector operations", false, false)
361+
345362
Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
346363
const VectorSplit &VS, ValueVector *cachePtr)
347364
: BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
@@ -414,6 +431,19 @@ Value *Scatterer::operator[](unsigned Frag) {
414431
return CV[Frag];
415432
}
416433

434+
// Runs the scalarizer on F using the dominator tree and the pass's Options.
// Returns false without touching F when skipFunction(F) says to skip it;
// otherwise returns whatever ScalarizerVisitor::visit(F) reports.
bool ScalarizerLegacyPass::runOnFunction(Function &F) {
435+
// NOTE(review): the functions in test/CodeGen/DirectX/sin.ll are annotated
// optnone. If skipFunction() honours optnone, this pass would be skipped for
// them even though it is meant for legalization. Confirm this is intended.
if (skipFunction(F))
436+
return false;
437+
438+
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
439+
ScalarizerVisitor Impl(DT, Options);
440+
return Impl.visit(F);
441+
}
442+
443+
// Factory for the legacy scalarizer pass: returns a newly allocated
// ScalarizerLegacyPass with the options set by its constructor.
// Presumably the pass manager the result is added to takes ownership; verify
// against the caller.
FunctionPass *llvm::createScalarizerPass() {
444+
return new ScalarizerLegacyPass();
445+
}
446+
417447
bool ScalarizerVisitor::visit(Function &F) {
418448
assert(Gathered.empty() && Scattered.empty());
419449

llvm/test/CodeGen/DirectX/sin.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
; Function Attrs: noinline nounwind optnone
88
define noundef float @sin_float(float noundef %a) #0 {
99
entry:
10-
%a.addr = alloca float, align 4
11-
store float %a, ptr %a.addr, align 4
12-
%0 = load float, ptr %a.addr, align 4
13-
%1 = call float @llvm.sin.f32(float %0)
10+
%1 = call float @llvm.sin.f32(float %a)
1411
ret float %1
1512
}
1613

1714
; Function Attrs: noinline nounwind optnone
1815
define noundef half @sin_half(half noundef %a) #0 {
1916
entry:
20-
%a.addr = alloca half, align 2
21-
store half %a, ptr %a.addr, align 2
22-
%0 = load half, ptr %a.addr, align 2
23-
%1 = call half @llvm.sin.f16(half %0)
17+
%1 = call half @llvm.sin.f16(half %a)
2418
ret half %1
2519
}
20+
21+
; Vector case: calls llvm.sin.v4f32 on a <4 x float> argument and returns the
; result unchanged.
define noundef <4 x float> @sin_float4(<4 x float> noundef %a) #0 {
22+
entry:
23+
%2 = call <4 x float> @llvm.sin.v4f32(<4 x float> %a)
24+
ret <4 x float> %2
25+
}

0 commit comments

Comments
 (0)