Skip to content

Commit ef62cad

Browse files
authored
[SYCL][NVPTX] Split max_work_group_size into 3 NVVM annotations (#14420)
NVVM IR supports separated maxntidx, maxntidy, and maxntidz annotations. The backend will print them individually as three dimensions. This better preserves programmer intent than prematurely flattening them together. Note that the semantics are in fact identical; the CUDA implementation internally multiplies all dimensions together and only guarantees that the total is never exceeded, but not that any individual dimension is not exceeded. Thus 64,1,1 is identical to 4,4,4. We try and preserve a logical mapping of dimensions by index flipping between SYCL (z,y,x) and NVVM (x,y,z) in CUDA terminology despite, as mentioned above, it being largely irrelevant. Also this patch simplifies the attribute's getter functions as all dimensions are mandatory, and the getters seemed copied from the reqd_work_group_size attribute where some are optional. We could probably improve the code further by making the operands "unsigned" and not "Expr", and renaming them from X,Y,Z to Dim{0,1,2} as per the SYCL spec. This has been left for future work, however, as there's a non-trivial amount of code that expects to be able to treat the max_work_group_size and reqd_work_group_size attributes identically through templates and identical helper methods.
1 parent 7cb3107 commit ef62cad

File tree

4 files changed

+55
-41
lines changed

4 files changed

+55
-41
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,20 +1705,14 @@ def SYCLIntelMaxWorkGroupSize : InheritableAttr {
17051705
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
17061706
let Subjects = SubjectList<[Function], ErrorDiag>;
17071707
let AdditionalMembers = [{
1708-
std::optional<llvm::APSInt> getXDimVal() const {
1709-
if (const auto *CE = dyn_cast<ConstantExpr>(getXDim()))
1710-
return CE->getResultAsAPSInt();
1711-
return std::nullopt;
1708+
unsigned getXDimVal() const {
1709+
return cast<ConstantExpr>(getXDim())->getResultAsAPSInt().getExtValue();
17121710
}
1713-
std::optional<llvm::APSInt> getYDimVal() const {
1714-
if (const auto *CE = dyn_cast<ConstantExpr>(getYDim()))
1715-
return CE->getResultAsAPSInt();
1716-
return std::nullopt;
1711+
unsigned getYDimVal() const {
1712+
return cast<ConstantExpr>(getYDim())->getResultAsAPSInt().getExtValue();
17171713
}
1718-
std::optional<llvm::APSInt> getZDimVal() const {
1719-
if (const auto *CE = dyn_cast<ConstantExpr>(getZDim()))
1720-
return CE->getResultAsAPSInt();
1721-
return std::nullopt;
1714+
unsigned getZDimVal() const {
1715+
return cast<ConstantExpr>(getZDim())->getResultAsAPSInt().getExtValue();
17221716
}
17231717
}];
17241718
let Documentation = [SYCLIntelMaxWorkGroupSizeAttrDocs];

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,9 +826,9 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
826826
// Attributes arguments (first and third) are reversed on SYCLDevice.
827827
if (getLangOpts().SYCLIsDevice) {
828828
llvm::Metadata *AttrMDArgs[] = {
829-
llvm::ConstantAsMetadata::get(Builder.getInt(*A->getZDimVal())),
830-
llvm::ConstantAsMetadata::get(Builder.getInt(*A->getYDimVal())),
831-
llvm::ConstantAsMetadata::get(Builder.getInt(*A->getXDimVal()))};
829+
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDimVal())),
830+
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDimVal())),
831+
llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDimVal()))};
832832
Fn->setMetadata("max_work_group_size",
833833
llvm::MDNode::get(Context, AttrMDArgs));
834834
}

clang/lib/CodeGen/Targets/NVPTX.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,13 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
252252
bool HasMaxWorkGroupSize = false;
253253
bool HasMinWorkGroupPerCU = false;
254254
if (const auto *MWGS = FD->getAttr<SYCLIntelMaxWorkGroupSizeAttr>()) {
255-
auto MaxThreads = (*MWGS->getZDimVal()).getExtValue() *
256-
(*MWGS->getYDimVal()).getExtValue() *
257-
(*MWGS->getXDimVal()).getExtValue();
258-
if (MaxThreads > 0) {
259-
addNVVMMetadata(F, "maxntidx", MaxThreads);
260-
HasMaxWorkGroupSize = true;
261-
}
255+
HasMaxWorkGroupSize = true;
256+
// We must index-flip between SYCL's notation, X,Y,Z (aka dim0,dim1,dim2)
257+
// with the fastest-moving dimension rightmost, to CUDA's, where X is the
258+
// fastest-moving dimension.
259+
addNVVMMetadata(F, "maxntidx", MWGS->getZDimVal());
260+
addNVVMMetadata(F, "maxntidy", MWGS->getYDimVal());
261+
addNVVMMetadata(F, "maxntidz", MWGS->getXDimVal());
262262
}
263263

264264
auto attrValue = [&](Expr *E) {

clang/test/CodeGenSYCL/launch_bounds_nvptx.cpp

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// compute unit and maximum work groups per multi-processor attributes, that
55
// correspond to CUDA's launch bounds. Expect max_work_group_size,
66
// min_work_groups_per_cu and max_work_groups_per_mp that are mapped to
7-
// maxntidx, minctasm, and maxclusterrank NVVM annotations respectively.
7+
// maxntid[xyz], minctasm, and maxclusterrank NVVM annotations respectively.
88

99
#include "sycl.hpp"
1010

@@ -13,24 +13,24 @@ queue q;
1313

1414
class Foo {
1515
public:
16-
[[intel::max_work_group_size(8, 8, 8), intel::min_work_groups_per_cu(2),
16+
[[intel::max_work_group_size(2, 4, 8), intel::min_work_groups_per_cu(2),
1717
intel::max_work_groups_per_mp(4)]] void
1818
operator()() const {}
1919
};
2020

2121
template <int N> class Functor {
2222
public:
23-
[[intel::max_work_group_size(N, 8, 8), intel::min_work_groups_per_cu(N),
23+
[[intel::max_work_group_size(N, 4, 8), intel::min_work_groups_per_cu(N),
2424
intel::max_work_groups_per_mp(N)]] void
2525
operator()() const {}
2626
};
2727

2828
template <int N>
29-
[[intel::max_work_group_size(N, 8, 8), intel::min_work_groups_per_cu(N),
29+
[[intel::max_work_group_size(N, 4, 8), intel::min_work_groups_per_cu(N),
3030
intel::max_work_groups_per_mp(N)]] void
3131
zoo() {}
3232

33-
[[intel::max_work_group_size(8, 8, 8), intel::min_work_groups_per_cu(2),
33+
[[intel::max_work_group_size(2, 4, 8), intel::min_work_groups_per_cu(2),
3434
intel::max_work_groups_per_mp(4)]] void
3535
bar() {}
3636

@@ -42,7 +42,7 @@ int main() {
4242

4343
// Test attribute is applied on lambda.
4444
h.single_task<class kernel_name2>(
45-
[] [[intel::max_work_group_size(8, 8, 8),
45+
[] [[intel::max_work_group_size(2, 4, 8),
4646
intel::min_work_groups_per_cu(2),
4747
intel::max_work_groups_per_mp(4)]] () {});
4848

@@ -65,41 +65,61 @@ int main() {
6565
// CHECK: define dso_local void @{{.*}}kernel_name4() #0 {{.*}} !min_work_groups_per_cu ![[MWGPC:[0-9]+]] !max_work_groups_per_mp ![[MWGPM:[0-9]+]] !max_work_group_size ![[MWGS:[0-9]+]]
6666
// CHECK: define dso_local void @{{.*}}kernel_name5() #0 {{.*}} !min_work_groups_per_cu ![[MWGPC_MWGPM_2:[0-9]+]] !max_work_groups_per_mp ![[MWGPC_MWGPM_2]] !max_work_group_size ![[MWGS_3:[0-9]+]]
6767

68-
// CHECK: {{.*}}@{{.*}}kernel_name1, !"maxntidx", i32 512}
68+
// CHECK: {{.*}}@{{.*}}kernel_name1, !"maxntidx", i32 8}
69+
// CHECK: {{.*}}@{{.*}}kernel_name1, !"maxntidy", i32 4}
70+
// CHECK: {{.*}}@{{.*}}kernel_name1, !"maxntidz", i32 2}
6971
// CHECK: {{.*}}@{{.*}}kernel_name1, !"minctasm", i32 2}
7072
// CHECK: {{.*}}@{{.*}}kernel_name1, !"maxclusterrank", i32 4}
71-
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"maxntidx", i32 512}
73+
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"maxntidx", i32 8}
74+
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"maxntidy", i32 4}
75+
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"maxntidz", i32 2}
7276
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"minctasm", i32 2}
7377
// CHECK: {{.*}}@{{.*}}Foo{{.*}}, !"maxclusterrank", i32 4}
74-
// CHECK: {{.*}}@{{.*}}kernel_name2, !"maxntidx", i32 512}
78+
// CHECK: {{.*}}@{{.*}}kernel_name2, !"maxntidx", i32 8}
79+
// CHECK: {{.*}}@{{.*}}kernel_name2, !"maxntidy", i32 4}
80+
// CHECK: {{.*}}@{{.*}}kernel_name2, !"maxntidz", i32 2}
7581
// CHECK: {{.*}}@{{.*}}kernel_name2, !"minctasm", i32 2}
7682
// CHECK: {{.*}}@{{.*}}kernel_name2, !"maxclusterrank", i32 4}
77-
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"maxntidx", i32 512}
83+
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"maxntidx", i32 8}
84+
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"maxntidy", i32 4}
85+
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"maxntidz", i32 2}
7886
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"minctasm", i32 2}
7987
// CHECK: {{.*}}@{{.*}}main{{.*}}, !"maxclusterrank", i32 4}
80-
// CHECK: {{.*}}@{{.*}}kernel_name3, !"maxntidx", i32 384}
88+
// CHECK: {{.*}}@{{.*}}kernel_name3, !"maxntidx", i32 8}
89+
// CHECK: {{.*}}@{{.*}}kernel_name3, !"maxntidy", i32 4}
90+
// CHECK: {{.*}}@{{.*}}kernel_name3, !"maxntidz", i32 6}
8191
// CHECK: {{.*}}@{{.*}}kernel_name3, !"minctasm", i32 6}
8292
// CHECK: {{.*}}@{{.*}}kernel_name3, !"maxclusterrank", i32 6}
83-
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"maxntidx", i32 384}
93+
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"maxntidx", i32 8}
94+
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"maxntidy", i32 4}
95+
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"maxntidz", i32 6}
8496
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"minctasm", i32 6}
8597
// CHECK: {{.*}}@{{.*}}Functor{{.*}}, !"maxclusterrank", i32 6}
86-
// CHECK: {{.*}}@{{.*}}kernel_name4, !"maxntidx", i32 512}
98+
// CHECK: {{.*}}@{{.*}}kernel_name4, !"maxntidx", i32 8}
99+
// CHECK: {{.*}}@{{.*}}kernel_name4, !"maxntidy", i32 4}
100+
// CHECK: {{.*}}@{{.*}}kernel_name4, !"maxntidz", i32 2}
87101
// CHECK: {{.*}}@{{.*}}kernel_name4, !"minctasm", i32 2}
88102
// CHECK: {{.*}}@{{.*}}kernel_name4, !"maxclusterrank", i32 4}
89-
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"maxntidx", i32 512}
103+
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"maxntidx", i32 8}
104+
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"maxntidy", i32 4}
105+
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"maxntidz", i32 2}
90106
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"minctasm", i32 2}
91107
// CHECK: {{.*}}@{{.*}}bar{{.*}}, !"maxclusterrank", i32 4}
92-
// CHECK: {{.*}}@{{.*}}kernel_name5, !"maxntidx", i32 1024}
108+
// CHECK: {{.*}}@{{.*}}kernel_name5, !"maxntidx", i32 8}
109+
// CHECK: {{.*}}@{{.*}}kernel_name5, !"maxntidy", i32 4}
110+
// CHECK: {{.*}}@{{.*}}kernel_name5, !"maxntidz", i32 16}
93111
// CHECK: {{.*}}@{{.*}}kernel_name5, !"minctasm", i32 16}
94112
// CHECK: {{.*}}@{{.*}}kernel_name5, !"maxclusterrank", i32 16}
95-
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"maxntidx", i32 1024}
113+
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"maxntidx", i32 8}
114+
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"maxntidy", i32 4}
115+
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"maxntidz", i32 16}
96116
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"minctasm", i32 16}
97117
// CHECK: {{.*}}@{{.*}}zoo{{.*}}, !"maxclusterrank", i32 16}
98118

99119
// CHECK: ![[MWGPC]] = !{i32 2}
100120
// CHECK: ![[MWGPM]] = !{i32 4}
101-
// CHECK: ![[MWGS]] = !{i32 8, i32 8, i32 8}
121+
// CHECK: ![[MWGS]] = !{i32 8, i32 4, i32 2}
102122
// CHECK: ![[MWGPC_MWGPM]] = !{i32 6}
103-
// CHECK: ![[MWGS_2]] = !{i32 8, i32 8, i32 6}
123+
// CHECK: ![[MWGS_2]] = !{i32 8, i32 4, i32 6}
104124
// CHECK: ![[MWGPC_MWGPM_2]] = !{i32 16}
105-
// CHECK: ![[MWGS_3]] = !{i32 8, i32 8, i32 16}
125+
// CHECK: ![[MWGS_3]] = !{i32 8, i32 4, i32 16}

0 commit comments

Comments
 (0)