@@ -38,6 +38,7 @@ static llvm::StringRef ExtractStringFromMDNodeOperand(const MDNode *N,
38
38
SYCLDeviceRequirements
39
39
llvm::computeDeviceRequirements (const module_split::ModuleDesc &MD) {
40
40
SYCLDeviceRequirements Reqs;
41
+ bool MultipleReqdWGSize = false ;
41
42
// Process all functions in the module
42
43
for (const Function &F : MD.getModule ()) {
43
44
if (auto *MDN = F.getMetadata (" sycl_used_aspects" )) {
@@ -70,6 +71,8 @@ llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) {
70
71
ExtractUnsignedIntegerFromMDNodeOperand (MDN, I));
71
72
if (!Reqs.ReqdWorkGroupSize .has_value ())
72
73
Reqs.ReqdWorkGroupSize = NewReqdWorkGroupSize;
74
+ if (Reqs.ReqdWorkGroupSize != NewReqdWorkGroupSize)
75
+ MultipleReqdWGSize = true ;
73
76
}
74
77
75
78
if (auto *MDN = F.getMetadata (" sycl_joint_matrix" )) {
@@ -105,6 +108,14 @@ llvm::computeDeviceRequirements(const module_split::ModuleDesc &MD) {
105
108
assert (*Reqs.SubGroupSize == static_cast <uint32_t >(MDValue));
106
109
}
107
110
}
111
+
112
+ // Usually, we would only expect one ReqdWGSize, as the module passed to
113
+ // this function would be split according to that. However, when splitting
114
+ // is disabled, this cannot be guaranteed. In this case, we reset the value,
115
+ // which ensures that no reqd_work_group_size data is attached
116
+ // to the device image.
117
+ if (MultipleReqdWGSize)
118
+ Reqs.ReqdWorkGroupSize .reset ();
108
119
return Reqs;
109
120
}
110
121
0 commit comments