Skip to content

Commit 8d7a32a

Browse files
authored
[SYCL] Enable template parameter support for loop_unroll attribute (#1060)
Signed-off-by: Viktoria Maksimova <[email protected]>
1 parent c19372e commit 8d7a32a

File tree

8 files changed

+133
-53
lines changed

8 files changed

+133
-53
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,17 +1182,22 @@ def OpenCLUnrollHint : InheritableAttr {
11821182

11831183
def LoopUnrollHint : InheritableAttr {
11841184
let Spellings = [CXX11<"clang","loop_unroll">];
1185-
let Args = [UnsignedArgument<"UnrollHint">];
1185+
let Args = [ExprArgument<"UnrollHintExpr">];
11861186
let LangOpts = [SYCLIsDevice, SYCLIsHost];
1187+
let HasCustomTypeTransform = 1;
11871188
let AdditionalMembers = [{
11881189
static const char *getName() {
11891190
return "loop_unroll";
11901191
}
1191-
std::string getDiagnosticName() const {
1192-
std::string Value = "";
1193-
if (getUnrollHint())
1194-
Value = "(" + std::to_string(getUnrollHint()) + ")";
1195-
return "[[clang::loop_unroll" + Value + "]]";
1192+
std::string getDiagnosticName(const PrintingPolicy &Policy) const {
1193+
std::string ValueName;
1194+
llvm::raw_string_ostream OS(ValueName);
1195+
if (auto *E = getUnrollHintExpr()) {
1196+
OS << "(";
1197+
E->printPretty(OS, nullptr, Policy);
1198+
OS << ")";
1199+
}
1200+
return "[[clang::loop_unroll" + OS.str() + "]]";
11961201
}
11971202
}];
11981203
let Documentation = [LoopUnrollHintDocs];

clang/include/clang/Sema/Sema.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,11 @@ class Sema final {
16731673
FPGALoopAttrT *BuildSYCLIntelFPGALoopAttr(const AttributeCommonInfo &A,
16741674
Expr *E);
16751675

1676+
LoopUnrollHintAttr *BuildLoopUnrollHintAttr(const AttributeCommonInfo &A,
1677+
Expr *E);
1678+
OpenCLUnrollHintAttr *
1679+
BuildOpenCLLoopUnrollHintAttr(const AttributeCommonInfo &A, Expr *E);
1680+
16761681
bool CheckQualifiedFunctionForTypeId(QualType T, SourceLocation Loc);
16771682

16781683
bool CheckFunctionReturnType(QualType T, SourceLocation Loc);

clang/lib/CodeGen/CGLoopInfo.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -706,8 +706,12 @@ void LoopInfoStack::push(BasicBlock *Header, clang::ASTContext &Ctx,
706706
// 1 - disable unroll.
707707
// other positive integer n - unroll by n.
708708
if (OpenCLHint || UnrollHint) {
709-
ValueInt = OpenCLHint ? OpenCLHint->getUnrollHint()
710-
: UnrollHint->getUnrollHint();
709+
ValueInt = 0;
710+
if (OpenCLHint)
711+
ValueInt = OpenCLHint->getUnrollHint();
712+
else if (Expr *E = UnrollHint->getUnrollHintExpr())
713+
ValueInt = E->EvaluateKnownConstInt(Ctx).getSExtValue();
714+
711715
if (ValueInt == 0) {
712716
State = LoopHintAttr::Enable;
713717
} else if (ValueInt != 1) {

clang/lib/Sema/SemaStmtAttr.cpp

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,47 @@ void CheckForIncompatibleUnrollHintAttributes(
554554
SourceLocation Loc = Range.getBegin();
555555
S.Diag(Loc, diag::err_loop_unroll_compatibility)
556556
<< PragmaUnroll->getDiagnosticName(Policy)
557-
<< AttrUnroll->getDiagnosticName();
557+
<< AttrUnroll->getDiagnosticName(Policy);
558558
}
559559
}
560560

561-
template <typename LoopUnrollAttrT>
561+
static bool CheckLoopUnrollAttrExpr(Sema &S, Expr *E,
562+
const AttributeCommonInfo &A,
563+
unsigned *UnrollFactor = nullptr) {
564+
if (E && !E->isInstantiationDependent()) {
565+
llvm::APSInt ArgVal(32);
566+
567+
if (!E->isIntegerConstantExpr(ArgVal, S.Context))
568+
return S.Diag(E->getExprLoc(), diag::err_attribute_argument_type)
569+
<< A.getAttrName() << AANT_ArgumentIntegerConstant
570+
<< E->getSourceRange();
571+
572+
if (ArgVal.isNonPositive())
573+
return S.Diag(E->getExprLoc(),
574+
diag::err_attribute_requires_positive_integer)
575+
<< A.getAttrName() << /* positive */ 0;
576+
577+
if (UnrollFactor)
578+
*UnrollFactor = ArgVal.getZExtValue();
579+
}
580+
return false;
581+
}
582+
583+
LoopUnrollHintAttr *Sema::BuildLoopUnrollHintAttr(const AttributeCommonInfo &A,
584+
Expr *E) {
585+
return !CheckLoopUnrollAttrExpr(*this, E, A)
586+
? new (Context) LoopUnrollHintAttr(Context, A, E)
587+
: nullptr;
588+
}
589+
590+
OpenCLUnrollHintAttr *
591+
Sema::BuildOpenCLLoopUnrollHintAttr(const AttributeCommonInfo &A, Expr *E) {
592+
unsigned UnrollFactor = 0;
593+
return !CheckLoopUnrollAttrExpr(*this, E, A, &UnrollFactor)
594+
? new (Context) OpenCLUnrollHintAttr(Context, A, UnrollFactor)
595+
: nullptr;
596+
}
597+
562598
static Attr *handleLoopUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A,
563599
SourceRange Range) {
564600
// Although the feature was introduced only in OpenCL C v2.0 s6.11.5, it's
@@ -574,30 +610,13 @@ static Attr *handleLoopUnrollHint(Sema &S, Stmt *St, const ParsedAttr &A,
574610
return nullptr;
575611
}
576612

577-
unsigned UnrollFactor = 0;
578-
579-
if (NumArgs == 1) {
580-
Expr *E = A.getArgAsExpr(0);
581-
llvm::APSInt ArgVal(32);
582-
583-
if (!E->isIntegerConstantExpr(ArgVal, S.Context)) {
584-
S.Diag(A.getLoc(), diag::err_attribute_argument_type)
585-
<< A << AANT_ArgumentIntegerConstant << E->getSourceRange();
586-
return nullptr;
587-
}
588-
589-
int Val = ArgVal.getSExtValue();
590-
591-
if (Val <= 0) {
592-
S.Diag(A.getRange().getBegin(),
593-
diag::err_attribute_requires_positive_integer)
594-
<< A << /* positive */ 0;
595-
return nullptr;
596-
}
597-
UnrollFactor = Val;
598-
}
613+
Expr *E = NumArgs ? A.getArgAsExpr(0) : nullptr;
614+
if (A.getParsedKind() == ParsedAttr::AT_OpenCLUnrollHint)
615+
return S.BuildOpenCLLoopUnrollHintAttr(A, E);
616+
else if (A.getParsedKind() == ParsedAttr::AT_LoopUnrollHint)
617+
return S.BuildLoopUnrollHintAttr(A, E);
599618

600-
return LoopUnrollAttrT::CreateImplicit(S.Context, UnrollFactor);
619+
return nullptr;
601620
}
602621

603622
static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
@@ -620,9 +639,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
620639
case ParsedAttr::AT_SYCLIntelFPGAMaxConcurrency:
621640
return handleIntelFPGALoopAttr<SYCLIntelFPGAMaxConcurrencyAttr>(S, A);
622641
case ParsedAttr::AT_OpenCLUnrollHint:
623-
return handleLoopUnrollHint<OpenCLUnrollHintAttr>(S, St, A, Range);
624642
case ParsedAttr::AT_LoopUnrollHint:
625-
return handleLoopUnrollHint<LoopUnrollHintAttr>(S, St, A, Range);
643+
return handleLoopUnrollHint(S, St, A, Range);
626644
case ParsedAttr::AT_Suppress:
627645
return handleSuppressAttr(S, St, A, Range);
628646
default:

clang/lib/Sema/SemaTemplateInstantiate.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ namespace {
10081008
const SYCLIntelFPGAMaxConcurrencyAttr *
10091009
TransformSYCLIntelFPGAMaxConcurrencyAttr(
10101010
const SYCLIntelFPGAMaxConcurrencyAttr *MC);
1011+
const LoopUnrollHintAttr *
1012+
TransformLoopUnrollHintAttr(const LoopUnrollHintAttr *LU);
10111013

10121014
ExprResult TransformPredefinedExpr(PredefinedExpr *E);
10131015
ExprResult TransformDeclRefExpr(DeclRefExpr *E);
@@ -1429,6 +1431,13 @@ TemplateInstantiator::TransformSYCLIntelFPGAMaxConcurrencyAttr(
14291431
*MC, TransformedExpr);
14301432
}
14311433

1434+
const LoopUnrollHintAttr *TemplateInstantiator::TransformLoopUnrollHintAttr(
1435+
const LoopUnrollHintAttr *LU) {
1436+
Expr *TransformedExpr =
1437+
getDerived().TransformExpr(LU->getUnrollHintExpr()).get();
1438+
return getSema().BuildLoopUnrollHintAttr(*LU, TransformedExpr);
1439+
}
1440+
14321441
ExprResult TemplateInstantiator::transformNonTypeTemplateParmRef(
14331442
NonTypeTemplateParmDecl *parm,
14341443
SourceLocation loc,
Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,54 @@
11
// RUN: %clang_cc1 -triple spir64-unknown-unknown-sycldevice -disable-llvm-passes -fsycl-is-device -emit-llvm %s -o - | FileCheck %s
22

3-
// CHECK: br label %for.cond, !llvm.loop ![[COUNT:[0-9]+]]
4-
// CHECK: br label %while.cond, !llvm.loop ![[DISABLE:[0-9]+]]
5-
// CHECK: br i1 %{{.*}}, label %do.body, label %do.end, !llvm.loop ![[ENABLE:[0-9]+]]
63

7-
// CHECK: ![[COUNT]] = distinct !{![[COUNT]], ![[COUNT_A:[0-9]+]]}
8-
// CHECK-NEXT: ![[COUNT_A]] = !{!"llvm.loop.unroll.count", i32 8}
4+
void enable() {
5+
int i = 1000;
6+
// CHECK: br i1 %{{.*}}, label %do.body, label %do.end, !llvm.loop ![[ENABLE:[0-9]+]]
7+
[[clang::loop_unroll]]
8+
do {} while (i--);
9+
}
10+
11+
template <int A>
912
void count() {
13+
// CHECK: br label %for.cond, !llvm.loop ![[COUNT:[0-9]+]]
1014
[[clang::loop_unroll(8)]]
1115
for (int i = 0; i < 1000; ++i);
16+
// CHECK: br label %for.cond2, !llvm.loop ![[COUNT_TEMPLATE:[0-9]+]]
17+
[[clang::loop_unroll(A)]]
18+
for (int i = 0; i < 1000; ++i);
1219
}
1320

14-
// CHECK: ![[DISABLE]] = distinct !{![[DISABLE]], ![[DISABLE_A:[0-9]+]]}
15-
// CHECK-NEXT: ![[DISABLE_A]] = !{!"llvm.loop.unroll.disable"}
21+
template <int A>
1622
void disable() {
17-
int i = 1000;
23+
int i = 1000, j = 100;
24+
// CHECK: br label %while.cond, !llvm.loop ![[DISABLE:[0-9]+]]
1825
[[clang::loop_unroll(1)]]
26+
while (j--);
27+
// CHECK: br label %while.cond1, !llvm.loop ![[DISABLE_TEMPLATE:[0-9]+]]
28+
[[clang::loop_unroll(A)]]
1929
while (i--);
2030
}
2131

22-
// CHECK: ![[ENABLE]] = distinct !{![[ENABLE]], ![[ENABLE_A:[0-9]+]]}
23-
// CHECK-NEXT: ![[ENABLE_A]] = !{!"llvm.loop.unroll.enable"}
24-
void enable() {
25-
int i = 1000;
26-
[[clang::loop_unroll]]
27-
do {} while (i--);
28-
}
29-
3032
template <typename name, typename Func>
3133
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
3234
kernelFunc();
3335
}
3436

3537
int main() {
3638
kernel_single_task<class kernel_function>([]() {
37-
count();
38-
disable();
39+
count<4>();
40+
disable<1>();
3941
enable();
4042
});
4143
return 0;
4244
}
45+
46+
// CHECK: ![[ENABLE]] = distinct !{![[ENABLE]], ![[ENABLE_A:[0-9]+]]}
47+
// CHECK-NEXT: ![[ENABLE_A]] = !{!"llvm.loop.unroll.enable"}
48+
// CHECK: ![[COUNT]] = distinct !{![[COUNT]], ![[COUNT_A:[0-9]+]]}
49+
// CHECK-NEXT: ![[COUNT_A]] = !{!"llvm.loop.unroll.count", i32 8}
50+
// CHECK: ![[COUNT_TEMPLATE]] = distinct !{![[COUNT_TEMPLATE]], ![[COUNT_TEMPLATE_A:[0-9]+]]}
51+
// CHECK-NEXT: ![[COUNT_TEMPLATE_A]] = !{!"llvm.loop.unroll.count", i32 4}
52+
// CHECK: ![[DISABLE]] = distinct !{![[DISABLE]], ![[DISABLE_A:[0-9]+]]}
53+
// CHECK-NEXT: ![[DISABLE_A]] = !{!"llvm.loop.unroll.disable"}
54+
// CHECKL ![[DISABLE_TEMPLATE]] = distinct !{!![[DISABLE_TEMPLATE]], ![[DISABLE_A]]}

clang/test/CodeGenSYCL/loop_unroll_host.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
// CHECK: br label %{{.*}}, !llvm.loop ![[COUNT:[0-9]+]]
33
// CHECK: br label %{{.*}}, !llvm.loop ![[DISABLE:[0-9]+]]
44
// CHECK: br i1 %{{.*}}, label %{{.*}}, label %{{.*}}, !llvm.loop ![[ENABLE:[0-9]+]]
5+
// CHECK: br label %{{.*}}, !llvm.loop ![[COUNT_TEMPLATE:[0-9]+]]
6+
// CHECK: br label %{{.*}}, !llvm.loop ![[DISABLE_TEMPLATE:[0-9]+]]
7+
8+
template <int A>
9+
void unroll() {
10+
[[clang::loop_unroll(A)]]
11+
for (int i = 0; i < 1000; ++i);
12+
}
513

614
int main() {
715
// CHECK: ![[COUNT]] = distinct !{![[COUNT]], ![[COUNT_A:[0-9]+]]}
@@ -18,5 +26,11 @@ int main() {
1826
i = 1000;
1927
[[clang::loop_unroll]]
2028
do {} while (i--);
29+
30+
// CHECK: ![[COUNT_TEMPLATE]] = distinct !{![[COUNT_TEMPLATE]], ![[COUNT_TEMPLATE_A:[0-9]+]]}
31+
// CHECK-NEXT: ![[COUNT_TEMPLATE_A]] = !{!"llvm.loop.unroll.count", i32 8}
32+
unroll<8>();
33+
// CHECK: ![[DISABLE_TEMPLATE]] = distinct !{![[DISABLE_TEMPLATE]], ![[DISABLE_A]]}
34+
unroll<1>();
2135
return 0;
2236
}

clang/test/SemaSYCL/loop_unroll.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
// RUN: %clang_cc1 -fsycl-is-device -fsyntax-only -verify -pedantic %s
22

3+
template <int A>
4+
void bar() {
5+
// expected-error@+1 {{'loop_unroll' attribute requires a positive integral compile time constant expression}}
6+
[[clang::loop_unroll(A)]]
7+
for (int i = 0; i < 10; ++i);
8+
}
9+
310
void foo() {
411
// expected-error@+1 {{clang loop attributes must be applied to for, while, or do statements}}
512
[[clang::loop_unroll(8)]] int a[10];
@@ -44,6 +51,12 @@ void foo() {
4451
constexpr int c = 4;
4552
[[clang::loop_unroll(c)]]
4653
for (int i = 0; i < 10; ++i);
54+
55+
// expected-note@+1 {{in instantiation of function template specialization}}
56+
bar<-4>();
57+
58+
// no error expected
59+
bar<c>();
4760
}
4861

4962
template <typename name, typename Func>

0 commit comments

Comments
 (0)