Skip to content

Commit 297f0b3

Browse files
[CudaSPIRV] Allow using integral non-type template parameters as attribute args (#131546)
Allow using integral non-type template parameters as attribute arguments of reqd_work_group_size and work_group_size_hint. Test plan: ninja check-all
1 parent d09ecb0 commit 297f0b3

File tree

10 files changed

+221
-41
lines changed

10 files changed

+221
-41
lines changed

clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const MatchFinder::MatchResult &Result) {
5454
bool IsNDRange = false;
5555
if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
5656
const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
57-
if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
58-
Attribute->getZDim() > 1)
57+
auto Eval = [&](Expr *E) {
58+
return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
59+
.getExtValue();
60+
};
61+
if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
62+
Eval(Attribute->getZDim()) > 1)
5963
IsNDRange = true;
6064
}
6165
if (IsNDRange) // No warning if kernel is treated as an NDRange.

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3044,18 +3044,15 @@ def NoDeref : TypeAttr {
30443044
def ReqdWorkGroupSize : InheritableAttr {
30453045
// Does not have a [[]] spelling because it is an OpenCL-related attribute.
30463046
let Spellings = [GNU<"reqd_work_group_size">];
3047-
let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
3048-
UnsignedArgument<"ZDim">];
3047+
let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
30493048
let Subjects = SubjectList<[Function], ErrorDiag>;
30503049
let Documentation = [Undocumented];
30513050
}
30523051

30533052
def WorkGroupSizeHint : InheritableAttr {
30543053
// Does not have a [[]] spelling because it is an OpenCL-related attribute.
30553054
let Spellings = [GNU<"work_group_size_hint">];
3056-
let Args = [UnsignedArgument<"XDim">,
3057-
UnsignedArgument<"YDim">,
3058-
UnsignedArgument<"ZDim">];
3055+
let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
30593056
let Subjects = SubjectList<[Function], ErrorDiag>;
30603057
let Documentation = [Undocumented];
30613058
}

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
649649
}
650650

651651
if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
652+
auto Eval = [&](Expr *E) {
653+
return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
654+
};
652655
llvm::Metadata *AttrMDArgs[] = {
653-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
654-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
655-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
656+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
657+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
658+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
656659
Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, AttrMDArgs));
657660
}
658661

659662
if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
663+
auto Eval = [&](Expr *E) {
664+
return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
665+
};
660666
llvm::Metadata *AttrMDArgs[] = {
661-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
662-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
663-
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
667+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
668+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
669+
llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
664670
Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, AttrMDArgs));
665671
}
666672

clang/lib/CodeGen/Targets/AMDGPU.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,12 +753,16 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
753753
int32_t *MaxThreadsVal) {
754754
unsigned Min = 0;
755755
unsigned Max = 0;
756+
auto Eval = [&](Expr *E) {
757+
return E->EvaluateKnownConstInt(getContext()).getExtValue();
758+
};
756759
if (FlatWGS) {
757-
Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
758-
Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
760+
Min = Eval(FlatWGS->getMin());
761+
Max = Eval(FlatWGS->getMax());
759762
}
760763
if (ReqdWGS && Min == 0 && Max == 0)
761-
Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
764+
Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
765+
Eval(ReqdWGS->getZDim());
762766

763767
if (Min != 0) {
764768
assert(Min <= Max && "Min must be less than or equal Max");

clang/lib/CodeGen/Targets/TCE.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,21 @@ void TCETargetCodeGenInfo::setTargetAttributes(
5050
M.getModule().getOrInsertNamedMetadata(
5151
"opencl.kernel_wg_size_info");
5252

53-
SmallVector<llvm::Metadata *, 5> Operands;
54-
Operands.push_back(llvm::ConstantAsMetadata::get(F));
55-
56-
Operands.push_back(
53+
auto Eval = [&](Expr *E) {
54+
return E->EvaluateKnownConstInt(FD->getASTContext());
55+
};
56+
SmallVector<llvm::Metadata *, 5> Operands{
57+
llvm::ConstantAsMetadata::get(F),
5758
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
58-
M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
59-
Operands.push_back(
59+
M.Int32Ty, Eval(Attr->getXDim()))),
6060
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
61-
M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
62-
Operands.push_back(
61+
M.Int32Ty, Eval(Attr->getYDim()))),
6362
llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
64-
M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
65-
66-
// Add a boolean constant operand for "required" (true) or "hint"
67-
// (false) for implementing the work_group_size_hint attr later.
68-
// Currently always true as the hint is not yet implemented.
69-
Operands.push_back(
70-
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getTrue(Context)));
63+
M.Int32Ty, Eval(Attr->getZDim()))),
64+
// Add a boolean constant operand for "required" (true) or "hint"
65+
// (false) for implementing the work_group_size_hint attr later.
66+
// Currently always true as the hint is not yet implemented.
67+
llvm::ConstantAsMetadata::get(llvm::ConstantInt::getTrue(Context))};
7168
OpenCLMetadata->addOperand(llvm::MDNode::get(Context, Operands));
7269
}
7370
}

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2914,32 +2914,93 @@ static void handleWeakImportAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
29142914
D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
29152915
}
29162916

2917+
// Checks whether an argument of launch_bounds-like attribute is
2918+
// acceptable, performs implicit conversion to Rvalue, and returns
2919+
// non-nullptr Expr result on success. Otherwise, it returns nullptr
2920+
// and may output an error.
2921+
template <class Attribute>
2922+
static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
2923+
const unsigned Idx) {
2924+
if (S.DiagnoseUnexpandedParameterPack(E))
2925+
return nullptr;
2926+
2927+
// Accept template arguments for now as they depend on something else.
2928+
// We'll get to check them when they eventually get instantiated.
2929+
if (E->isValueDependent())
2930+
return E;
2931+
2932+
std::optional<llvm::APSInt> I = llvm::APSInt(64);
2933+
if (!(I = E->getIntegerConstantExpr(S.Context))) {
2934+
S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
2935+
<< &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
2936+
return nullptr;
2937+
}
2938+
// Make sure we can fit it in 32 bits.
2939+
if (!I->isIntN(32)) {
2940+
S.Diag(E->getExprLoc(), diag::err_ice_too_large)
2941+
<< toString(*I, 10, false) << 32 << /* Unsigned */ 1;
2942+
return nullptr;
2943+
}
2944+
if (*I < 0)
2945+
S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
2946+
<< &Attr << /*non-negative*/ 1 << E->getSourceRange();
2947+
2948+
// We may need to perform implicit conversion of the argument.
2949+
InitializedEntity Entity = InitializedEntity::InitializeParameter(
2950+
S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
2951+
ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
2952+
assert(!ValArg.isInvalid() &&
2953+
"Unexpected PerformCopyInitialization() failure.");
2954+
2955+
return ValArg.getAs<Expr>();
2956+
}
2957+
29172958
// Handles reqd_work_group_size and work_group_size_hint.
29182959
template <typename WorkGroupAttr>
29192960
static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
2920-
uint32_t WGSize[3];
2961+
Expr *WGSize[3];
29212962
for (unsigned i = 0; i < 3; ++i) {
2922-
const Expr *E = AL.getArgAsExpr(i);
2923-
if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
2924-
/*StrictlyUnsigned=*/true))
2963+
if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
2964+
WGSize[i] = E;
2965+
else
29252966
return;
29262967
}
29272968

2928-
if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
2969+
auto IsZero = [&](Expr *E) {
2970+
if (E->isValueDependent())
2971+
return false;
2972+
std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
2973+
assert(I && "Non-integer constant expr");
2974+
return I->isZero();
2975+
};
2976+
2977+
if (!llvm::all_of(WGSize, IsZero)) {
29292978
for (unsigned i = 0; i < 3; ++i) {
29302979
const Expr *E = AL.getArgAsExpr(i);
2931-
if (WGSize[i] == 0) {
2980+
if (IsZero(WGSize[i])) {
29322981
S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
29332982
<< AL << E->getSourceRange();
29342983
return;
29352984
}
29362985
}
29372986
}
29382987

2988+
auto Equal = [&](Expr *LHS, Expr *RHS) {
2989+
if (LHS->isValueDependent() || RHS->isValueDependent())
2990+
return true;
2991+
std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
2992+
assert(L && "Non-integer constant expr");
2993+
std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
2994+
assert(L && "Non-integer constant expr");
2995+
return L == R;
2996+
};
2997+
29392998
WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
2940-
if (Existing && !(Existing->getXDim() == WGSize[0] &&
2941-
Existing->getYDim() == WGSize[1] &&
2942-
Existing->getZDim() == WGSize[2]))
2999+
if (Existing &&
3000+
!llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
3001+
Existing->getYDim(),
3002+
Existing->getZDim()},
3003+
WGSize, Equal))
29433004
S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
29443005

29453006
D->addAttr(::new (S.Context)

clang/lib/Sema/SemaTemplateInstantiateDecl.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,32 @@ static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
572572
S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
573573
}
574574

575+
static void instantiateDependentReqdWorkGroupSizeAttr(
576+
Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
577+
const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
578+
// Both min and max expression are constant expressions.
579+
EnterExpressionEvaluationContext Unevaluated(
580+
S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
581+
582+
ExprResult Result = S.SubstExpr(Attr.getXDim(), TemplateArgs);
583+
if (Result.isInvalid())
584+
return;
585+
Expr *X = Result.getAs<Expr>();
586+
587+
Result = S.SubstExpr(Attr.getYDim(), TemplateArgs);
588+
if (Result.isInvalid())
589+
return;
590+
Expr *Y = Result.getAs<Expr>();
591+
592+
Result = S.SubstExpr(Attr.getZDim(), TemplateArgs);
593+
if (Result.isInvalid())
594+
return;
595+
Expr *Z = Result.getAs<Expr>();
596+
597+
ASTContext &Context = S.getASTContext();
598+
New->addAttr(::new (Context) ReqdWorkGroupSizeAttr(Context, Attr, X, Y, Z));
599+
}
600+
575601
ExplicitSpecifier Sema::instantiateExplicitSpecifier(
576602
const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
577603
if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
812838
continue;
813839
}
814840

841+
if (const auto *ReqdWorkGroupSize =
842+
dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
843+
instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
844+
*ReqdWorkGroupSize, New);
845+
}
846+
815847
if (const auto *AMDGPUFlatWorkGroupSize =
816848
dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
817849
instantiateDependentAMDGPUFlatWorkGroupSizeAttr(

clang/test/CodeGenCUDASPIRV/spirv-attrs.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@ __global__ void vec_type_hint_int() {}
1818
__attribute__((intel_reqd_sub_group_size(64)))
1919
__global__ void intel_reqd_sub_group_size_64() {}
2020

21+
template <unsigned a, unsigned b, unsigned c>
22+
__attribute__((reqd_work_group_size(a, b, c)))
23+
__global__ void reqd_work_group_size_a_b_c() {}
24+
25+
template __global__ void reqd_work_group_size_a_b_c<256,1,1>(void);
2126

2227
// CHECK: define spir_kernel void @_Z26reqd_work_group_size_0_0_0v() #[[ATTR:[0-9]+]] !reqd_work_group_size ![[WG_SIZE_ZEROS:[0-9]+]]
2328
// CHECK: define spir_kernel void @_Z28reqd_work_group_size_128_1_1v() #[[ATTR:[0-9]+]] !reqd_work_group_size ![[WG_SIZE:[0-9]+]]
2429
// CHECK: define spir_kernel void @_Z26work_group_size_hint_2_2_2v() #[[ATTR]] !work_group_size_hint ![[WG_HINT:[0-9]+]]
2530
// CHECK: define spir_kernel void @_Z17vec_type_hint_intv() #[[ATTR]] !vec_type_hint ![[VEC_HINT:[0-9]+]]
2631
// CHECK: define spir_kernel void @_Z28intel_reqd_sub_group_size_64v() #[[ATTR]] !intel_reqd_sub_group_size ![[SUB_GRP:[0-9]+]]
32+
// CHECK: define spir_kernel void @_Z26reqd_work_group_size_a_b_cILj256ELj1ELj1EEvv() #[[ATTR]] comdat !reqd_work_group_size ![[WG_SIZE_TMPL:[0-9]+]]
2733

2834
// CHECK: attributes #[[ATTR]] = { {{.*}} }
2935

@@ -32,3 +38,4 @@ __global__ void intel_reqd_sub_group_size_64() {}
3238
// CHECK: ![[WG_HINT]] = !{i32 2, i32 2, i32 2}
3339
// CHECK: ![[VEC_HINT]] = !{i32 poison, i32 1}
3440
// CHECK: ![[SUB_GRP]] = !{i32 64}
41+
// CHECK: ![[WG_SIZE_TMPL]] = !{i32 256, i32 1, i32 1}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
2+
// RUN: -fcuda-is-device -verify -fsyntax-only %s
3+
4+
#define __global__ __attribute__((global))
5+
6+
__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
7+
__global__ void TestTooBigArg1(void);
8+
9+
__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
10+
__global__ void TestTooBigArg2(void);
11+
12+
template <int... Args>
13+
__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
14+
__global__ void TestTemplateVariadicArgs1(void) {}
15+
16+
template <int... Args>
17+
__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
18+
__global__ void TestTemplateVariadicArgs2(void) {}
19+
20+
template <class a> // expected-note {{declared here}}
21+
__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
22+
__global__ void TestTemplateArgClass1(void) {}
23+
24+
template <class a> // expected-note {{declared here}}
25+
__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
26+
__global__ void TestTemplateArgClass2(void) {}
27+
28+
constexpr int A = 512;
29+
30+
__attribute__((reqd_work_group_size(A, A, A)))
31+
__global__ void TestConstIntArg1(void) {}
32+
33+
__attribute__((work_group_size_hint(A, A, A)))
34+
__global__ void TestConstIntArg2(void) {}
35+
36+
int B = 512;
37+
__attribute__((reqd_work_group_size(B, 1, 1))) // expected-error {{attribute requires parameter 0 to be an integer constant}}
38+
__global__ void TestNonConstIntArg1(void) {}
39+
40+
__attribute__((work_group_size_hint(B, 1, 1))) // expected-error {{attribute requires parameter 0 to be an integer constant}}
41+
__global__ void TestNonConstIntArg2(void) {}
42+
43+
constexpr int C = -512;
44+
__attribute__((reqd_work_group_size(C, 1, 1))) // expected-error {{attribute requires a non-negative integral compile time constant expression}}
45+
__global__ void TestNegativeConstIntArg1(void) {}
46+
47+
__attribute__((work_group_size_hint(C, 1, 1))) // expected-error {{attribute requires a non-negative integral compile time constant expression}}
48+
__global__ void TestNegativeConstIntArg2(void) {}
49+
50+
51+
__attribute__((reqd_work_group_size(A, 0, 1))) // expected-error {{attribute must be greater than 0}}
52+
__global__ void TestZeroArg1(void) {}
53+
54+
__attribute__((work_group_size_hint(A, 0, 1))) // expected-error {{attribute must be greater than 0}}
55+
__global__ void TestZeroArg2(void) {}
56+
57+
58+

clang/test/SemaCUDA/spirv-attrs.cu

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,28 @@
33
// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
44
// RUN: -fcuda-is-device -verify -fsyntax-only %s
55

6-
#include "Inputs/cuda.h"
6+
#define __global__ __attribute__((global))
77

88
__attribute__((reqd_work_group_size(128, 1, 1)))
99
__global__ void reqd_work_group_size_128_1_1() {}
1010

11+
template <unsigned a, unsigned b, unsigned c>
12+
__attribute__((reqd_work_group_size(a, b, c)))
13+
__global__ void reqd_work_group_size_a_b_c() {}
14+
15+
template <>
16+
__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
17+
1118
__attribute__((work_group_size_hint(2, 2, 2)))
1219
__global__ void work_group_size_hint_2_2_2() {}
1320

21+
template <unsigned a, unsigned b, unsigned c>
22+
__attribute__((work_group_size_hint(a, b, c)))
23+
__global__ void work_group_size_hint_a_b_c() {}
24+
25+
template <>
26+
__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
27+
1428
__attribute__((vec_type_hint(int)))
1529
__global__ void vec_type_hint_int() {}
1630

0 commit comments

Comments
 (0)