Skip to content

Commit 92fc1eb

Browse files
authored
[HLSL] add loop unroll (#93879)
spec: microsoft/hlsl-specs#263 - `Attr.td` - Define the HLSL loop attribute hints (unroll and loop) - `AttrDocs.td` - Add documentation for unroll and loop - `CGLoopInfo.cpp` - Add codegen for HLSL unroll that maps to clang unroll expectations - `ParseStmt.cpp` - For statements, if HLSL is enabled, define DeclSpecAttrs via MaybeParseMicrosoftAttributes - `SemaStmtAttr.cpp` - Add the HLSL loop unroll handling. Resolves #70114. dxc examples: - for loop: https://hlsl.godbolt.org/z/8EK6Pa139 - while loop: https://hlsl.godbolt.org/z/ebr5MvEcK - do while: https://hlsl.godbolt.org/z/be8cedoTs Documentation: ![Screenshot_20240531_143000](https://github.com/llvm/llvm-project/assets/1802579/9da9df9b-68a6-49eb-9d4f-e080aa2eff7f)
1 parent 8901c1c commit 92fc1eb

File tree

7 files changed

+343
-8
lines changed

7 files changed

+343
-8
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4172,6 +4172,18 @@ def LoopHint : Attr {
41724172
let HasCustomParsing = 1;
41734173
}
41744174

4175+
/// The HLSL loop attributes
4176+
def HLSLLoopHint: StmtAttr {
4177+
/// [unroll(directive)]
4178+
/// [loop]
4179+
let Spellings = [Microsoft<"unroll">, Microsoft<"loop">];
4180+
let Args = [UnsignedArgument<"directive", /*opt*/1>];
4181+
let Subjects = SubjectList<[ForStmt, WhileStmt, DoStmt],
4182+
ErrorDiag, "'for', 'while', and 'do' statements">;
4183+
let LangOpts = [HLSL];
4184+
let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
4185+
}
4186+
41754187
def CapturedRecord : InheritableAttr {
41764188
// This attribute has no spellings as it is only ever created implicitly.
41774189
let Spellings = [];

clang/include/clang/Basic/AttrDocs.td

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7343,6 +7343,100 @@ where shaders must be compiled into a library and linked at runtime.
73437343
}];
73447344
}
73457345

7346+
def HLSLLoopHintDocs : Documentation {
7347+
let Category = DocCatStmt;
7348+
let Heading = "[loop]";
7349+
let Content = [{
7350+
The ``[loop]`` directive allows loop optimization hints to be
7351+
specified for the subsequent loop. The directive allows unrolling to
7352+
be disabled and is not compatible with [unroll(x)].
7353+
7354+
Specifying ``[loop]`` directs the
7355+
unroller to not unroll the loop.
7356+
7357+
.. code-block:: hlsl
7358+
7359+
[loop]
7360+
for (...) {
7361+
...
7362+
}
7363+
7364+
.. code-block:: hlsl
7365+
7366+
[loop]
7367+
while (...) {
7368+
...
7369+
}
7370+
7371+
.. code-block:: hlsl
7372+
7373+
[loop]
7374+
do {
7375+
...
7376+
} while (...)
7377+
7378+
See `hlsl loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
7379+
for details.
7380+
}];
7381+
}
7382+
7383+
def HLSLUnrollHintDocs : Documentation {
7384+
let Category = DocCatStmt;
7385+
let Heading = "[unroll(x)], [unroll]";
7386+
let Content = [{
7387+
Loop unrolling optimization hints can be specified with ``[unroll(x)]``.
7388+
The attribute is placed immediately before a for, while,
7389+
or do-while.
7390+
Specifying the parameter, ``[unroll(_value_)]``, directs the
7391+
unroller to unroll the loop ``_value_`` times. Note: [unroll(x)] is not compatible with [loop].
7392+
7393+
.. code-block:: hlsl
7394+
7395+
[unroll(4)]
7396+
for (...) {
7397+
...
7398+
}
7399+
7400+
.. code-block:: hlsl
7401+
7402+
[unroll]
7403+
for (...) {
7404+
...
7405+
}
7406+
7407+
.. code-block:: hlsl
7408+
7409+
[unroll(4)]
7410+
while (...) {
7411+
...
7412+
}
7413+
7414+
.. code-block:: hlsl
7415+
7416+
[unroll]
7417+
while (...) {
7418+
...
7419+
}
7420+
7421+
.. code-block:: hlsl
7422+
7423+
[unroll(4)]
7424+
do {
7425+
...
7426+
} while (...)
7427+
7428+
.. code-block:: hlsl
7429+
7430+
[unroll]
7431+
do {
7432+
...
7433+
} while (...)
7434+
7435+
See `hlsl loop extensions <https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-for>`_
7436+
for details.
7437+
}];
7438+
}
7439+
73467440
def ClangRandomizeLayoutDocs : Documentation {
73477441
let Category = DocCatDecl;
73487442
let Heading = "randomize_layout, no_randomize_layout";
@@ -7402,7 +7496,8 @@ b for constant buffer views (CBV).
74027496

74037497
Register space is specified in the format ``space[number]`` and defaults to ``space0`` if omitted.
74047498
Here're resource binding examples with and without space:
7405-
.. code-block:: c++
7499+
7500+
.. code-block:: hlsl
74067501

74077502
RWBuffer<float> Uav : register(u3, space1);
74087503
Buffer<float> Buf : register(t1);
@@ -7420,7 +7515,7 @@ A subcomponent is a register number, which is an integer. A component is in the
74207515

74217516
Examples:
74227517

7423-
.. code-block:: c++
7518+
.. code-block:: hlsl
74247519

74257520
cbuffer A {
74267521
float3 a : packoffset(c0.y);

clang/lib/CodeGen/CGLoopInfo.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,9 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
612612
const LoopHintAttr *LH = dyn_cast<LoopHintAttr>(Attr);
613613
const OpenCLUnrollHintAttr *OpenCLHint =
614614
dyn_cast<OpenCLUnrollHintAttr>(Attr);
615-
615+
const HLSLLoopHintAttr *HLSLLoopHint = dyn_cast<HLSLLoopHintAttr>(Attr);
616616
// Skip non loop hint attributes
617-
if (!LH && !OpenCLHint) {
617+
if (!LH && !OpenCLHint && !HLSLLoopHint) {
618618
continue;
619619
}
620620

@@ -635,6 +635,17 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
635635
Option = LoopHintAttr::UnrollCount;
636636
State = LoopHintAttr::Numeric;
637637
}
638+
} else if (HLSLLoopHint) {
639+
ValueInt = HLSLLoopHint->getDirective();
640+
if (HLSLLoopHint->getSemanticSpelling() ==
641+
HLSLLoopHintAttr::Spelling::Microsoft_unroll) {
642+
if (ValueInt == 0)
643+
State = LoopHintAttr::Enable;
644+
if (ValueInt > 0) {
645+
Option = LoopHintAttr::UnrollCount;
646+
State = LoopHintAttr::Numeric;
647+
}
648+
}
638649
} else if (LH) {
639650
auto *ValueExpr = LH->getValue();
640651
if (ValueExpr) {

clang/lib/Parse/ParseStmt.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,21 @@ Parser::ParseStatementOrDeclaration(StmtVector &Stmts,
114114
// here because we don't want to allow arbitrary orderings.
115115
ParsedAttributes CXX11Attrs(AttrFactory);
116116
MaybeParseCXX11Attributes(CXX11Attrs, /*MightBeObjCMessageSend*/ true);
117-
ParsedAttributes GNUAttrs(AttrFactory);
117+
ParsedAttributes GNUOrMSAttrs(AttrFactory);
118118
if (getLangOpts().OpenCL)
119-
MaybeParseGNUAttributes(GNUAttrs);
119+
MaybeParseGNUAttributes(GNUOrMSAttrs);
120+
121+
if (getLangOpts().HLSL)
122+
MaybeParseMicrosoftAttributes(GNUOrMSAttrs);
120123

121124
StmtResult Res = ParseStatementOrDeclarationAfterAttributes(
122-
Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUAttrs);
125+
Stmts, StmtCtx, TrailingElseLoc, CXX11Attrs, GNUOrMSAttrs);
123126
MaybeDestroyTemplateIds();
124127

125128
// Attributes that are left should all go on the statement, so concatenate the
126129
// two lists.
127130
ParsedAttributes Attrs(AttrFactory);
128-
takeAndConcatenateAttrs(CXX11Attrs, GNUAttrs, Attrs);
131+
takeAndConcatenateAttrs(CXX11Attrs, GNUOrMSAttrs, Attrs);
129132

130133
assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) &&
131134
"attributes on empty statement");

clang/lib/Sema/SemaStmtAttr.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "clang/Basic/TargetInfo.h"
1717
#include "clang/Sema/DelayedDiagnostic.h"
1818
#include "clang/Sema/Lookup.h"
19+
#include "clang/Sema/ParsedAttr.h"
1920
#include "clang/Sema/ScopeInfo.h"
2021
#include "clang/Sema/SemaInternal.h"
2122
#include "llvm/ADT/StringExtras.h"
@@ -584,6 +585,39 @@ static Attr *handleOpenCLUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A,
584585
return ::new (S.Context) OpenCLUnrollHintAttr(S.Context, A, UnrollFactor);
585586
}
586587

588+
static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
589+
SourceRange Range) {
590+
591+
if (A.getSemanticSpelling() == HLSLLoopHintAttr::Spelling::Microsoft_loop &&
592+
!A.checkAtMostNumArgs(S, 0))
593+
return nullptr;
594+
595+
unsigned UnrollFactor = 0;
596+
if (A.getNumArgs() == 1) {
597+
598+
if (A.isArgIdent(0)) {
599+
S.Diag(A.getLoc(), diag::err_attribute_argument_type)
600+
<< A << AANT_ArgumentIntegerConstant << A.getRange();
601+
return nullptr;
602+
}
603+
604+
Expr *E = A.getArgAsExpr(0);
605+
606+
if (S.CheckLoopHintExpr(E, St->getBeginLoc(),
607+
/*AllowZero=*/false))
608+
return nullptr;
609+
610+
std::optional<llvm::APSInt> ArgVal = E->getIntegerConstantExpr(S.Context);
611+
// CheckLoopHintExpr handles non int const cases
612+
assert(ArgVal != std::nullopt && "ArgVal should be an integer constant.");
613+
int Val = ArgVal->getSExtValue();
614+
// CheckLoopHintExpr handles negative and zero cases
615+
assert(Val > 0 && "Val should be a positive integer greater than zero.");
616+
UnrollFactor = static_cast<unsigned>(Val);
617+
}
618+
return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
619+
}
620+
587621
static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
588622
SourceRange Range) {
589623
if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
@@ -618,6 +652,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
618652
return handleFallThroughAttr(S, St, A, Range);
619653
case ParsedAttr::AT_LoopHint:
620654
return handleLoopHintAttr(S, St, A, Range);
655+
case ParsedAttr::AT_HLSLLoopHint:
656+
return handleHLSLLoopHintAttr(S, St, A, Range);
621657
case ParsedAttr::AT_OpenCLUnrollHint:
622658
return handleOpenCLUnrollHint(S, St, A, Range);
623659
case ParsedAttr::AT_Suppress:
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-library -disable-llvm-passes %s -emit-llvm -o - | FileCheck %s
3+
4+
/*** for ***/
5+
void for_count()
6+
{
7+
// CHECK-LABEL: for_count
8+
[unroll(8)]
9+
for( int i = 0; i < 1000; ++i);
10+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISTINCT:.*]]
11+
}
12+
13+
void for_disable()
14+
{
15+
// CHECK-LABEL: for_disable
16+
[loop]
17+
for( int i = 0; i < 1000; ++i);
18+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_DISABLE:.*]]
19+
}
20+
21+
void for_enable()
22+
{
23+
// CHECK-LABEL: for_enable
24+
[unroll]
25+
for( int i = 0; i < 1000; ++i);
26+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_ENABLE:.*]]
27+
}
28+
29+
void for_nested_one_unroll_enable()
30+
{
31+
// CHECK-LABEL: for_nested_one_unroll_enable
32+
int s = 0;
33+
[unroll]
34+
for( int i = 0; i < 1000; ++i) {
35+
for( int j = 0; j < 10; ++j)
36+
s += i + j;
37+
}
38+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_ENABLE:.*]]
39+
// CHECK-NOT: br label %{{.*}}, !llvm.loop ![[FOR_NESTED_1_ENABLE:.*]]
40+
}
41+
42+
void for_nested_two_unroll_enable()
43+
{
44+
// CHECK-LABEL: for_nested_two_unroll_enable
45+
int s = 0;
46+
[unroll]
47+
for( int i = 0; i < 1000; ++i) {
48+
[unroll]
49+
for( int j = 0; j < 10; ++j)
50+
s += i + j;
51+
}
52+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_ENABLE:.*]]
53+
// CHECK: br label %{{.*}}, !llvm.loop ![[FOR_NESTED2_1_ENABLE:.*]]
54+
}
55+
56+
57+
/*** while ***/
58+
void while_count()
59+
{
60+
// CHECK-LABEL: while_count
61+
int i = 1000;
62+
[unroll(4)]
63+
while(i-->0);
64+
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISTINCT:.*]]
65+
}
66+
67+
void while_disable()
68+
{
69+
// CHECK-LABEL: while_disable
70+
int i = 1000;
71+
[loop]
72+
while(i-->0);
73+
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_DISABLE:.*]]
74+
}
75+
76+
void while_enable()
77+
{
78+
// CHECK-LABEL: while_enable
79+
int i = 1000;
80+
[unroll]
81+
while(i-->0);
82+
// CHECK: br label %{{.*}}, !llvm.loop ![[WHILE_ENABLE:.*]]
83+
}
84+
85+
/*** do ***/
86+
void do_count()
87+
{
88+
// CHECK-LABEL: do_count
89+
int i = 1000;
90+
[unroll(16)]
91+
do {} while(i--> 0);
92+
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISTINCT:.*]]
93+
}
94+
95+
void do_disable()
96+
{
97+
// CHECK-LABEL: do_disable
98+
int i = 1000;
99+
[loop]
100+
do {} while(i--> 0);
101+
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_DISABLE:.*]]
102+
}
103+
104+
void do_enable()
105+
{
106+
// CHECK-LABEL: do_enable
107+
int i = 1000;
108+
[unroll]
109+
do {} while(i--> 0);
110+
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[DO_ENABLE:.*]]
111+
}
112+
113+
114+
// CHECK: ![[FOR_DISTINCT]] = distinct !{![[FOR_DISTINCT]], ![[FOR_COUNT:.*]]}
115+
// CHECK: ![[FOR_COUNT]] = !{!"llvm.loop.unroll.count", i32 8}
116+
// CHECK: ![[FOR_DISABLE]] = distinct !{![[FOR_DISABLE]], ![[DISABLE:.*]]}
117+
// CHECK: ![[DISABLE]] = !{!"llvm.loop.unroll.disable"}
118+
// CHECK: ![[FOR_ENABLE]] = distinct !{![[FOR_ENABLE]], ![[ENABLE:.*]]}
119+
// CHECK: ![[ENABLE]] = !{!"llvm.loop.unroll.enable"}
120+
// CHECK: ![[FOR_NESTED_ENABLE]] = distinct !{![[FOR_NESTED_ENABLE]], ![[ENABLE]]}
121+
// CHECK: ![[FOR_NESTED2_ENABLE]] = distinct !{![[FOR_NESTED2_ENABLE]], ![[ENABLE]]}
122+
// CHECK: ![[FOR_NESTED2_1_ENABLE]] = distinct !{![[FOR_NESTED2_1_ENABLE]], ![[ENABLE]]}
123+
// CHECK: ![[WHILE_DISTINCT]] = distinct !{![[WHILE_DISTINCT]], ![[WHILE_COUNT:.*]]}
124+
// CHECK: ![[WHILE_COUNT]] = !{!"llvm.loop.unroll.count", i32 4}
125+
// CHECK: ![[WHILE_DISABLE]] = distinct !{![[WHILE_DISABLE]], ![[DISABLE]]}
126+
// CHECK: ![[WHILE_ENABLE]] = distinct !{![[WHILE_ENABLE]], ![[ENABLE]]}
127+
// CHECK: ![[DO_DISTINCT]] = distinct !{![[DO_DISTINCT]], ![[DO_COUNT:.*]]}
128+
// CHECK: ![[DO_COUNT]] = !{!"llvm.loop.unroll.count", i32 16}
129+
// CHECK: ![[DO_DISABLE]] = distinct !{![[DO_DISABLE]], ![[DISABLE]]}
130+
// CHECK: ![[DO_ENABLE]] = distinct !{![[DO_ENABLE]], ![[ENABLE]]}

0 commit comments

Comments
 (0)