Skip to content

Commit a586269

Browse files
authored
[SYCL][Fusion][NoSTL] Simplify kernel attributes by making them less generic (#12376)
Kernel attributes were previously modeled as an `std::string` name, and an arbitrarily long `std::vector` of string values, however we currently don't need this genericity to represent `reqd_work_group_size` and `work_group_size_hint` attributes. In fact, an attribute kind enum and an `Indices` object (= 3 integers) suffice. _This PR is part of a series of changes to remove uses of STL classes in the kernel fusion interface to prevent ABI issues in the future._ Signed-off-by: Julian Oppermann <[email protected]>
1 parent a851f1a commit a586269

18 files changed

+134
-114
lines changed

sycl-fusion/common/include/Kernel.h

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <array>
1616
#include <cassert>
1717
#include <cstdint>
18+
#include <cstring>
1819
#include <string>
1920
#include <vector>
2021

@@ -108,18 +109,6 @@ struct SYCLKernelBinaryInfo {
108109
uint64_t BinarySize = 0;
109110
};
110111

111-
///
112-
/// Describe a SYCL/OpenCL kernel attribute by its name and values.
113-
struct SYCLKernelAttribute {
114-
using AttributeValueList = std::vector<std::string>;
115-
116-
SYCLKernelAttribute(std::string Name)
117-
: AttributeName{std::move(Name)}, Values{} {}
118-
119-
std::string AttributeName;
120-
AttributeValueList Values;
121-
};
122-
123112
///
124113
/// Encode usage of parameters for the actual kernel function.
125114
enum ArgUsage : uint8_t {
@@ -149,10 +138,6 @@ struct SYCLArgumentDescriptor {
149138
DynArray<ArgUsageUT> UsageMask;
150139
};
151140

152-
///
153-
/// List of SYCL/OpenCL kernel attributes.
154-
using AttributeList = std::vector<SYCLKernelAttribute>;
155-
156141
///
157142
/// Class to model a three-dimensional index.
158143
class Indices {
@@ -193,6 +178,48 @@ class Indices {
193178
size_t Values[Size];
194179
};
195180

181+
///
182+
/// Describe a SYCL/OpenCL kernel attribute by its kind and values.
183+
struct SYCLKernelAttribute {
184+
enum class AttrKind { Invalid, ReqdWorkGroupSize, WorkGroupSizeHint };
185+
186+
static constexpr auto ReqdWorkGroupSizeName = "reqd_work_group_size";
187+
static constexpr auto WorkGroupSizeHintName = "work_group_size_hint";
188+
189+
static AttrKind parseKind(const char *Name) {
190+
auto Kind = AttrKind::Invalid;
191+
if (std::strcmp(Name, ReqdWorkGroupSizeName) == 0) {
192+
Kind = AttrKind::ReqdWorkGroupSize;
193+
} else if (std::strcmp(Name, WorkGroupSizeHintName) == 0) {
194+
Kind = AttrKind::WorkGroupSizeHint;
195+
}
196+
return Kind;
197+
}
198+
199+
AttrKind Kind;
200+
Indices Values;
201+
202+
SYCLKernelAttribute() : Kind(AttrKind::Invalid) {}
203+
SYCLKernelAttribute(AttrKind Kind, const Indices &Values)
204+
: Kind(Kind), Values(Values) {}
205+
206+
const char *getName() const {
207+
assert(Kind != AttrKind::Invalid);
208+
switch (Kind) {
209+
case AttrKind::ReqdWorkGroupSize:
210+
return ReqdWorkGroupSizeName;
211+
case AttrKind::WorkGroupSizeHint:
212+
return WorkGroupSizeHintName;
213+
default:
214+
return "__invalid__";
215+
}
216+
}
217+
};
218+
219+
///
220+
/// List of SYCL/OpenCL kernel attributes.
221+
using SYCLAttributeList = DynArray<SYCLKernelAttribute>;
222+
196223
///
197224
/// Class to model SYCL nd_range
198225
class NDRange {
@@ -306,7 +333,7 @@ struct SYCLKernelInfo {
306333

307334
SYCLArgumentDescriptor Args;
308335

309-
AttributeList Attributes;
336+
SYCLAttributeList Attributes;
310337

311338
NDRange NDR;
312339

sycl-fusion/jit-compiler/lib/translation/KernelTranslation.cpp

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,19 @@ using namespace jit_compiler::translation;
2424
using namespace llvm;
2525

2626
///
27-
/// Get an attribute value consisting of NumValues scalar constant integers
28-
/// from the MDNode.
29-
static void getAttributeValues(std::vector<std::string> &Values, MDNode *MD) {
30-
for (const auto &MDOp : MD->operands()) {
31-
auto *ConstantMD = cast<ConstantAsMetadata>(MDOp);
32-
auto *ConstInt = cast<ConstantInt>(ConstantMD->getValue());
33-
Values.push_back(std::to_string(ConstInt->getZExtValue()));
34-
}
27+
/// Get an `Indices` object from the MDNode's three constant integer operands.
28+
static Indices getAttributeValues(MDNode *MD) {
29+
assert(MD->getNumOperands() == Indices::size());
30+
Indices Res;
31+
std::transform(MD->op_begin(), MD->op_end(), Res.begin(),
32+
[](const auto &MDOp) {
33+
auto *ConstantMD = cast<ConstantAsMetadata>(MDOp);
34+
auto *ConstInt = cast<ConstantInt>(ConstantMD->getValue());
35+
return ConstInt->getZExtValue();
36+
});
37+
return Res;
3538
}
3639

37-
// NOLINTNEXTLINE(readability-identifier-naming)
38-
static const char *REQD_WORK_GROUP_SIZE_ATTR = "reqd_work_group_size";
39-
// NOLINTNEXTLINE(readability-identifier-naming)
40-
static const char *WORK_GROUP_SIZE_HINT_ATTR = "work_group_size_hint";
41-
4240
///
4341
/// Restore kernel attributes for the kernel in Info from the metadata
4442
/// attached to its kernel function in the LLVM module Mod.
@@ -48,16 +46,20 @@ static const char *WORK_GROUP_SIZE_HINT_ATTR = "work_group_size_hint";
4846
static void restoreKernelAttributes(Module *Mod, SYCLKernelInfo &Info) {
4947
auto *KernelFunction = Mod->getFunction(Info.Name);
5048
assert(KernelFunction && "Kernel function not present in module");
51-
if (auto *MD = KernelFunction->getMetadata(REQD_WORK_GROUP_SIZE_ATTR)) {
52-
SYCLKernelAttribute ReqdAttr{REQD_WORK_GROUP_SIZE_ATTR};
53-
getAttributeValues(ReqdAttr.Values, MD);
54-
Info.Attributes.push_back(ReqdAttr);
49+
SmallVector<SYCLKernelAttribute, 2> Attrs;
50+
using AttrKind = SYCLKernelAttribute::AttrKind;
51+
if (auto *MD = KernelFunction->getMetadata(
52+
SYCLKernelAttribute::ReqdWorkGroupSizeName)) {
53+
Attrs.emplace_back(AttrKind::ReqdWorkGroupSize, getAttributeValues(MD));
5554
}
56-
if (auto *MD = KernelFunction->getMetadata(WORK_GROUP_SIZE_HINT_ATTR)) {
57-
SYCLKernelAttribute HintAttr{WORK_GROUP_SIZE_HINT_ATTR};
58-
getAttributeValues(HintAttr.Values, MD);
59-
Info.Attributes.push_back(HintAttr);
55+
if (auto *MD = KernelFunction->getMetadata(
56+
SYCLKernelAttribute::WorkGroupSizeHintName)) {
57+
Attrs.emplace_back(AttrKind::WorkGroupSizeHint, getAttributeValues(MD));
6058
}
59+
if (Attrs.empty())
60+
return;
61+
Info.Attributes = SYCLAttributeList{Attrs.size()};
62+
llvm::copy(Attrs, Info.Attributes.begin());
6163
}
6264

6365
llvm::Expected<std::unique_ptr<llvm::Module>>

sycl-fusion/passes/kernel-fusion/SYCLKernelFusion.cpp

Lines changed: 27 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -444,10 +444,10 @@ Error SYCLKernelFusion::fuseKernel(
444444
assert(FusedParamKinds.size() == FusedArgUsageMask.size());
445445
jit_compiler::SYCLKernelInfo KI{FusedKernelName.str(),
446446
FusedParamKinds.size()};
447+
KI.Attributes = KernelAttributeList{FusedAttributes.size()};
447448
llvm::copy(FusedParamKinds, KI.Args.Kinds.begin());
448449
llvm::copy(FusedArgUsageMask, KI.Args.UsageMask.begin());
449-
KI.Attributes.insert(KI.Attributes.end(), FusedAttributes.begin(),
450-
FusedAttributes.end());
450+
llvm::copy(FusedAttributes, KI.Attributes.begin());
451451
ModInfo->addKernel(KI);
452452
}
453453
jit_compiler::SYCLKernelInfo &FusedKernelInfo =
@@ -700,16 +700,16 @@ void SYCLKernelFusion::attachKernelAttributeMD(
700700
// Attach kernel attribute information as metadata to a kernel function.
701701
for (jit_compiler::SYCLKernelAttribute &KernelAttr :
702702
FusedKernelInfo.Attributes) {
703-
if (KernelAttr.AttributeName == "reqd_work_group_size" ||
704-
KernelAttr.AttributeName == "work_group_size_hint") {
703+
if (KernelAttr.Kind == KernelAttrKind::ReqdWorkGroupSize ||
704+
KernelAttr.Kind == KernelAttrKind::WorkGroupSizeHint) {
705705
// 'reqd_work_group_size' and 'work_group_size_hint' get attached as
706706
// metadata with their three values as constant integer metadata.
707707
SmallVector<Metadata *, 3> MDValues;
708-
for (std::string &Val : KernelAttr.Values) {
708+
for (auto Val : KernelAttr.Values) {
709709
MDValues.push_back(ConstantAsMetadata::get(
710-
ConstantInt::get(Type::getInt32Ty(LLVMCtx), std::stoi(Val))));
710+
ConstantInt::get(Type::getInt32Ty(LLVMCtx), Val)));
711711
}
712-
attachFusedMetadata(FusedFunction, KernelAttr.AttributeName, MDValues);
712+
attachFusedMetadata(FusedFunction, KernelAttr.getName(), MDValues);
713713
}
714714
// The two kernel attributes above are currently the only attributes
715715
// attached as metadata, so we don't do anything for other attributes.
@@ -773,7 +773,7 @@ void SYCLKernelFusion::mergeKernelAttributes(
773773
// want to keep it anyways.
774774
for (const jit_compiler::SYCLKernelAttribute &OtherAttr : Other) {
775775
SYCLKernelFusion::KernelAttr *Attr =
776-
getAttribute(Attributes, OtherAttr.AttributeName);
776+
getAttribute(Attributes, OtherAttr.Kind);
777777
SYCLKernelFusion::AttrMergeResult MergeResult =
778778
mergeAttribute(Attr, OtherAttr);
779779
switch (MergeResult) {
@@ -786,7 +786,7 @@ void SYCLKernelFusion::mergeKernelAttributes(
786786
addAttribute(Attributes, OtherAttr);
787787
break;
788788
case AttrMergeResult::RemoveAttr:
789-
removeAttribute(Attributes, OtherAttr.AttributeName);
789+
removeAttribute(Attributes, OtherAttr.Kind);
790790
break;
791791
case AttrMergeResult::Error:
792792
llvm_unreachable("Failed to merge attribute");
@@ -798,14 +798,15 @@ void SYCLKernelFusion::mergeKernelAttributes(
798798
SYCLKernelFusion::AttrMergeResult
799799
SYCLKernelFusion::mergeAttribute(KernelAttr *Attr,
800800
const KernelAttr &Other) const {
801-
if (Other.AttributeName == "reqd_work_group_size") {
801+
switch (Other.Kind) {
802+
case KernelAttrKind::ReqdWorkGroupSize:
802803
return mergeReqdWorkgroupSize(Attr, Other);
803-
}
804-
if (Other.AttributeName == "work_group_size_hint") {
804+
case KernelAttrKind::WorkGroupSizeHint:
805805
return mergeWorkgroupSizeHint(Attr, Other);
806+
default:
807+
// Unknown attribute name, return an error.
808+
return SYCLKernelFusion::AttrMergeResult::Error;
806809
}
807-
// Unknown attribute name, return an error.
808-
return SYCLKernelFusion::AttrMergeResult::Error;
809810
}
810811

811812
SYCLKernelFusion::AttrMergeResult
@@ -816,11 +817,9 @@ SYCLKernelFusion::mergeReqdWorkgroupSize(KernelAttr *Attr,
816817
// new one
817818
return SYCLKernelFusion::AttrMergeResult::AddAttr;
818819
}
819-
for (size_t I = 0; I < 3; ++I) {
820-
if (getAttrValueAsInt(*Attr, I) != getAttrValueAsInt(Other, I)) {
821-
// Two different required work-group sizes, causes an error.
822-
return SYCLKernelFusion::AttrMergeResult::Error;
823-
}
820+
if (Attr->Values != Other.Values) {
821+
// Two different required work-group sizes, causes an error.
822+
return SYCLKernelFusion::AttrMergeResult::Error;
824823
}
825824
// The required workgroup sizes are identical, keep it.
826825
return SYCLKernelFusion::AttrMergeResult::KeepAttr;
@@ -834,20 +833,18 @@ SYCLKernelFusion::mergeWorkgroupSizeHint(KernelAttr *Attr,
834833
// the new one
835834
return SYCLKernelFusion::AttrMergeResult::AddAttr;
836835
}
837-
for (size_t I = 0; I < 3; ++I) {
838-
if (getAttrValueAsInt(*Attr, I) != getAttrValueAsInt(Other, I)) {
839-
// Two different hints, remove the hint altogether.
840-
return SYCLKernelFusion::AttrMergeResult::RemoveAttr;
841-
}
836+
if (Attr->Values != Other.Values) {
837+
// Two different hints, remove the hint altogether.
838+
return SYCLKernelFusion::AttrMergeResult::RemoveAttr;
842839
}
843840
// The given hint is identical, keep it.
844841
return SYCLKernelFusion::AttrMergeResult::KeepAttr;
845842
}
846843

847844
SYCLKernelFusion::KernelAttr *
848845
SYCLKernelFusion::getAttribute(MutableAttributeList &Attributes,
849-
StringRef AttrName) const {
850-
auto *It = findAttribute(Attributes, AttrName);
846+
KernelAttrKind AttrKind) const {
847+
auto *It = findAttribute(Attributes, AttrKind);
851848
if (It != Attributes.end()) {
852849
return &*It;
853850
}
@@ -860,25 +857,17 @@ void SYCLKernelFusion::addAttribute(MutableAttributeList &Attributes,
860857
}
861858

862859
void SYCLKernelFusion::removeAttribute(MutableAttributeList &Attributes,
863-
StringRef AttrName) const {
864-
auto *It = findAttribute(Attributes, AttrName);
860+
KernelAttrKind AttrKind) const {
861+
auto *It = findAttribute(Attributes, AttrKind);
865862
if (It != Attributes.end()) {
866863
Attributes.erase(It);
867864
}
868865
}
869866

870867
SYCLKernelFusion::MutableAttributeList::iterator
871868
SYCLKernelFusion::findAttribute(MutableAttributeList &Attributes,
872-
StringRef AttrName) const {
869+
KernelAttrKind AttrKind) const {
873870
return llvm::find_if(Attributes, [=](SYCLKernelFusion::KernelAttr &Attr) {
874-
return Attr.AttributeName == AttrName.str();
871+
return Attr.Kind == AttrKind;
875872
});
876873
}
877-
878-
unsigned SYCLKernelFusion::getAttrValueAsInt(const KernelAttr &Attr,
879-
size_t Idx) const {
880-
assert(Idx < Attr.Values.size());
881-
unsigned Result = 0;
882-
StringRef(Attr.Values[Idx]).getAsInteger(0, Result);
883-
return Result;
884-
}

sycl-fusion/passes/kernel-fusion/SYCLKernelFusion.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,10 @@ class SYCLKernelFusion : public llvm::PassInfoMixin<SYCLKernelFusion> {
144144
jit_compiler::SYCLArgumentDescriptor &InputDef,
145145
const llvm::ArrayRef<bool> ParamUseMask) const;
146146

147-
using KernelAttributeList = jit_compiler::AttributeList;
147+
using KernelAttributeList = jit_compiler::SYCLAttributeList;
148148

149149
using KernelAttr = jit_compiler::SYCLKernelAttribute;
150+
using KernelAttrKind = jit_compiler::SYCLKernelAttribute::AttrKind;
150151

151152
///
152153
/// Indicates the result of merging two attributes of the same kind.
@@ -187,30 +188,26 @@ class SYCLKernelFusion : public llvm::PassInfoMixin<SYCLKernelFusion> {
187188
const KernelAttr &Other) const;
188189

189190
///
190-
/// Get the attribute with the specified name from the list or return nullptr
191+
/// Get the attribute with the specified kind from the list or return nullptr
191192
/// in case no such attribute is present.
192193
KernelAttr *getAttribute(MutableAttributeList &Attributes,
193-
llvm::StringRef AttrName) const;
194+
KernelAttrKind AttrKind) const;
194195

195196
///
196197
/// Add the attribute to the list.
197198
void addAttribute(MutableAttributeList &Attributes,
198199
const KernelAttr &Attr) const;
199200

200201
///
201-
/// Remove the attribute with the specified name from the list, if present.
202+
/// Remove the attribute with the specified kind from the list, if present.
202203
void removeAttribute(MutableAttributeList &Attributes,
203-
llvm::StringRef AttrName) const;
204+
KernelAttrKind AttrKind) const;
204205

205206
///
206-
/// Find the attribute with the specified name in the list, or return the
207+
/// Find the attribute with the specified kind in the list, or return the
207208
/// end() iterator if no such attribute is present.
208209
MutableAttributeList::iterator findAttribute(MutableAttributeList &Attributes,
209-
llvm::StringRef AttrName) const;
210-
211-
///
212-
/// Retrieve the attribute value at the given index as unsigned integer.
213-
unsigned getAttrValueAsInt(const KernelAttr &Attr, size_t Idx) const;
210+
KernelAttrKind AttrKind) const;
214211
};
215212

216213
#endif // SYCL_FUSION_PASSES_SYCLKERNELFUSION_H

sycl-fusion/passes/kernel-info/SYCLKernelInfo.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,25 @@ void SYCLModuleInfoAnalysis::loadModuleInfoFromMetadata(Module &M) {
8484
KernelInfo.Args.UsageMask.begin(), getUInt<ArgUsageUT>);
8585

8686
// Operands 3..n: Attributes
87-
for (; It != End; ++It) {
88-
auto *AIMD = cast<MDNode>(*It);
89-
assert(AIMD->getNumOperands() > 1);
87+
KernelInfo.Attributes = jit_compiler::SYCLAttributeList{
88+
static_cast<size_t>(std::distance(It, End))};
89+
std::transform(It, End, KernelInfo.Attributes.begin(), [](const auto &Op) {
90+
auto *AIMD = cast<MDNode>(Op);
91+
assert(AIMD->getNumOperands() == 4);
9092
const auto *AttrIt = AIMD->op_begin(), *AttrEnd = AIMD->op_end();
9193

9294
// Operand 0: Attribute name
9395
auto Name = cast<MDString>(*AttrIt)->getString().str();
96+
auto Kind = SYCLKernelAttribute::parseKind(Name.c_str());
97+
assert(Kind != SYCLKernelAttribute::AttrKind::Invalid);
9498
++AttrIt;
9599

96-
// Operands 1..m: String values
97-
auto &KernelAttr = KernelInfo.Attributes.emplace_back(std::move(Name));
98-
for (; AttrIt != AttrEnd; ++AttrIt) {
99-
auto Value = cast<MDString>(*AttrIt)->getString().str();
100-
KernelAttr.Values.emplace_back(std::move(Value));
101-
}
102-
}
100+
// Operands 1..3: Values
101+
Indices Values;
102+
std::transform(AttrIt, AttrEnd, Values.begin(), getUInt<size_t>);
103+
104+
return SYCLKernelAttribute{Kind, Values};
105+
});
103106

104107
ModuleInfo->addKernel(KernelInfo);
105108
}
@@ -156,7 +159,7 @@ PreservedAnalyses SYCLModuleInfoPrinter::run(Module &Mod,
156159

157160
Out.indent(Indent) << "Attributes:\n";
158161
for (const auto &AttrInfo : KernelInfo.Attributes) {
159-
Out.indent(Indent * 2) << AttrInfo.AttributeName << ':';
162+
Out.indent(Indent * 2) << AttrInfo.getName() << ':';
160163
Out.PadToColumn(Pad);
161164
llvm::interleaveComma(AttrInfo.Values, Out);
162165
Out << '\n';

0 commit comments

Comments
 (0)