Skip to content

Commit 9b3d85f

Browse files
authored
[DirectX] TypedUAVLoadAdditionalFormats shader flag (#120477)
Set the TypedUAVLoadAddtionalFormats flag if the shader contains a load from a multicomponent UAV. Fixes #114557
1 parent b905bcc commit 9b3d85f

File tree

3 files changed

+88
-12
lines changed

3 files changed

+88
-12
lines changed

llvm/lib/Target/DirectX/DXILShaderFlags.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,21 @@
1414
#include "DXILShaderFlags.h"
1515
#include "DirectX.h"
1616
#include "llvm/ADT/STLExtras.h"
17+
#include "llvm/Analysis/DXILResource.h"
1718
#include "llvm/IR/Instruction.h"
19+
#include "llvm/IR/IntrinsicInst.h"
20+
#include "llvm/IR/Intrinsics.h"
21+
#include "llvm/IR/IntrinsicsDirectX.h"
1822
#include "llvm/IR/Module.h"
23+
#include "llvm/InitializePasses.h"
1924
#include "llvm/Support/FormatVariadic.h"
2025
#include "llvm/Support/raw_ostream.h"
2126

2227
using namespace llvm;
2328
using namespace llvm::dxil;
2429

25-
static void updateFunctionFlags(ComputedShaderFlags &CSF,
26-
const Instruction &I) {
30+
static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
31+
DXILResourceTypeMap &DRTM) {
2732
if (!CSF.Doubles)
2833
CSF.Doubles = I.getType()->isDoubleTy();
2934

@@ -44,9 +49,23 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
4449
break;
4550
}
4651
}
52+
53+
if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
54+
switch (II->getIntrinsicID()) {
55+
default:
56+
break;
57+
case Intrinsic::dx_typedBufferLoad: {
58+
dxil::ResourceTypeInfo &RTI =
59+
DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
60+
if (RTI.isTyped())
61+
CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
62+
}
63+
}
64+
}
4765
}
4866

49-
void ModuleShaderFlags::initialize(const Module &M) {
67+
void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
68+
5069
// Collect shader flags for each of the functions
5170
for (const auto &F : M.getFunctionList()) {
5271
if (F.isDeclaration()) {
@@ -57,7 +76,7 @@ void ModuleShaderFlags::initialize(const Module &M) {
5776
ComputedShaderFlags CSF;
5877
for (const auto &BB : F)
5978
for (const auto &I : BB)
60-
updateFunctionFlags(CSF, I);
79+
updateFunctionFlags(CSF, I, DRTM);
6180
// Insert shader flag mask for function F
6281
FunctionFlags.push_back({&F, CSF});
6382
// Update combined shader flags mask
@@ -104,8 +123,11 @@ AnalysisKey ShaderFlagsAnalysis::Key;
104123

105124
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
106125
ModuleAnalysisManager &AM) {
126+
DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
127+
107128
ModuleShaderFlags MSFI;
108-
MSFI.initialize(M);
129+
MSFI.initialize(M, DRTM);
130+
109131
return MSFI;
110132
}
111133

@@ -132,11 +154,22 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
132154
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
133155

134156
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
135-
MSFI.initialize(M);
157+
DXILResourceTypeMap &DRTM =
158+
getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
159+
160+
MSFI.initialize(M, DRTM);
136161
return false;
137162
}
138163

164+
void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
165+
AU.setPreservesAll();
166+
AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
167+
}
168+
139169
char ShaderFlagsAnalysisWrapper::ID = 0;
140170

141-
INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
142-
"DXIL Shader Flag Analysis", true, true)
171+
INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
172+
"DXIL Shader Flag Analysis", true, true)
173+
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
174+
INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
175+
"DXIL Shader Flag Analysis", true, true)

llvm/lib/Target/DirectX/DXILShaderFlags.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
namespace llvm {
2727
class Module;
2828
class GlobalVariable;
29+
class DXILResourceTypeMap;
2930

3031
namespace dxil {
3132

@@ -84,7 +85,7 @@ struct ComputedShaderFlags {
8485
};
8586

8687
struct ModuleShaderFlags {
87-
void initialize(const Module &);
88+
void initialize(const Module &, DXILResourceTypeMap &DRTM);
8889
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
8990
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
9091

@@ -135,9 +136,7 @@ class ShaderFlagsAnalysisWrapper : public ModulePass {
135136

136137
bool runOnModule(Module &M) override;
137138

138-
void getAnalysisUsage(AnalysisUsage &AU) const override {
139-
AU.setPreservesAll();
140-
}
139+
void getAnalysisUsage(AnalysisUsage &AU) const override;
141140
};
142141

143142
} // namespace dxil
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
2+
; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=CHECK-OBJ
3+
4+
target triple = "dxil-pc-shadermodel6.7-library"
5+
6+
; CHECK-OBJ: - Name: SFI0
7+
; CHECK-OBJ: Flags:
8+
; CHECK-OBJ: TypedUAVLoadAdditionalFormats: true
9+
10+
; CHECK: Combined Shader Flags for Module
11+
; CHECK-NEXT: Shader Flags Value: 0x00002000
12+
13+
; CHECK: Note: shader requires additional functionality:
14+
; CHECK: Typed UAV Load Additional Formats
15+
16+
; CHECK: Function multicomponent : 0x00002000
17+
define <4 x float> @multicomponent() #0 {
18+
%res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
19+
@llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
20+
%val = call <4 x float> @llvm.dx.typedBufferLoad(
21+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0)
22+
ret <4 x float> %val
23+
}
24+
25+
; CHECK: Function onecomponent : 0x00000000
26+
define float @onecomponent() #0 {
27+
%res = call target("dx.TypedBuffer", float, 1, 0, 0)
28+
@llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
29+
%val = call float @llvm.dx.typedBufferLoad(
30+
target("dx.TypedBuffer", float, 1, 0, 0) %res, i32 0)
31+
ret float %val
32+
}
33+
34+
; CHECK: Function noload : 0x00000000
35+
define void @noload(<4 x float> %val) #0 {
36+
%res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
37+
@llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
38+
call void @llvm.dx.typedBufferStore(
39+
target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0,
40+
<4 x float> %val)
41+
ret void
42+
}
43+
44+
attributes #0 = { convergent norecurse nounwind "hlsl.export"}

0 commit comments

Comments
 (0)