@@ -130,23 +130,125 @@ Function *JointMatrixFuncsResolutionPass::getEntryFunction(Function *F)
130
130
return nullptr ;
131
131
}
132
132
133
- void JointMatrixFuncsResolutionPass::ResolveSIMDSize (Function *F) {
134
- if (m_SIMDSize != 0 ) return ;
133
+ int32_t JointMatrixFuncsResolutionPass::DetermineForcedSIMDSize ()
134
+ {
135
+ int32_t forcedSIMDSize = m_Ctx->getModuleMetaData ()->csInfo .forcedSIMDSize ;
136
+
137
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD32) && IGC_IS_FLAG_DISABLED (ForceCSSIMD16) && (forcedSIMDSize == 32 || IGC_IS_FLAG_ENABLED (ForceCSSIMD32)))
138
+ {
139
+ if (forcedSIMDSize == 0 )
140
+ m_Ctx->getModuleMetaData ()->csInfo .forcedSIMDSize = 32 ;
141
+ return 32 ;
142
+ }
143
+
144
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD16) && IGC_IS_FLAG_DISABLED (ForceCSSIMD32) && (forcedSIMDSize == 16 || IGC_IS_FLAG_ENABLED (ForceCSSIMD16)))
145
+ {
146
+ if (forcedSIMDSize == 0 )
147
+ m_Ctx->getModuleMetaData ()->csInfo .forcedSIMDSize = 16 ;
148
+ return 16 ;
149
+ }
150
+
151
+ return forcedSIMDSize;
152
+ }
153
+
154
+ int32_t JointMatrixFuncsResolutionPass::DefineKernelSIMDSize ()
155
+ {
156
+ if (m_Ctx->platform .hasExecSize16DPAS ())
157
+ {
158
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD16) && IGC_IS_FLAG_DISABLED (ForceCSSIMD32))
159
+ return 16 ;
160
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD32) && IGC_IS_FLAG_DISABLED (ForceCSSIMD16))
161
+ return 32 ;
162
+ std::string msg = " Sub group sizes supported by Joint Matrix for this platform are disabled by flags or non-supported sub group size forced." ;
163
+ m_Ctx->EmitError (msg.c_str (), nullptr );
164
+ return 0 ;
165
+ }
166
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD32) && IGC_IS_FLAG_ENABLED (ForceCSSIMD32))
167
+ {
168
+ std::string msg = " Sub group size 32 forced by flags but not supported by Joint Matrix on this platform." ;
169
+ m_Ctx->EmitError (msg.c_str (), nullptr );
170
+ return 0 ;
171
+ }
172
+ if (IGC_IS_FLAG_ENABLED (EnableOCLSIMD16) && IGC_IS_FLAG_ENABLED (ForceCSSIMD16))
173
+ {
174
+ std::string msg = " Sub group size 16 forced by flags but not supported by Joint Matrix on this platform." ;
175
+ m_Ctx->EmitError (msg.c_str (), nullptr );
176
+ return 0 ;
177
+ }
178
+ return 8 ;
179
+ }
180
+
181
+ bool JointMatrixFuncsResolutionPass::IsSIMDSizeValid (int32_t simdSize)
182
+ {
183
+ return ((m_Ctx->platform .hasExecSize16DPAS () && (simdSize == 16 || simdSize == 32 )) ||
184
+ (!m_Ctx->platform .hasExecSize16DPAS () && simdSize == 8 ));
185
+ }
135
186
187
+ void JointMatrixFuncsResolutionPass::ForceKernelSIMDSize (Function *F, int32_t forcedSIMDSize)
188
+ {
136
189
Function *entryFunction = getEntryFunction (F);
137
- if (entryFunction) {
190
+ if (entryFunction) // if can find entry function
191
+ {
192
+ IGCMD::FunctionInfoMetaDataHandle funcInfoMD = m_mdUtils->getFunctionsInfoItem (entryFunction);
193
+ IGCMD::SubGroupSizeMetaDataHandle subGroupSize = funcInfoMD->getSubGroupSize ();
194
+ subGroupSize->setSIMD_size (forcedSIMDSize);
195
+ }
196
+ }
197
+
198
+ void JointMatrixFuncsResolutionPass::ResolveSIMDSize (Function *F)
199
+ {
200
+ if (m_SIMDSize != 0 )
201
+ return ;
202
+
203
+ int32_t forcedSIMDSize = DetermineForcedSIMDSize ();
204
+ if (forcedSIMDSize != 0 )
205
+ {
206
+ if (IsSIMDSizeValid (forcedSIMDSize))
207
+ {
208
+ m_SIMDSize = forcedSIMDSize;
209
+ ForceKernelSIMDSize (F, m_SIMDSize);
210
+ return ;
211
+ }
212
+ // if forced and not ok for platform exit with error
213
+ std::string msg = " Sub group size " + std::to_string (forcedSIMDSize) + " is forced by flags but not supported by Joint Matrix on this platform." ;
214
+ m_Ctx->EmitError (msg.c_str (), nullptr );
215
+ return ;
216
+ }
217
+
218
+ // if not forced by driver of flags, check on entry function level
219
+ Function *entryFunction = getEntryFunction (F);
220
+ if (entryFunction) // if can find entry function
221
+ {
138
222
IGCMD::FunctionInfoMetaDataHandle funcInfoMD = m_mdUtils->getFunctionsInfoItem (entryFunction);
139
223
IGCMD::SubGroupSizeMetaDataHandle subGroupSize = funcInfoMD->getSubGroupSize ();
140
224
if (subGroupSize->hasValue ())
141
- m_SIMDSize = subGroupSize->getSIMD_size ();
225
+ {
226
+ int32_t kernelSIMDSize = subGroupSize->getSIMD_size ();
227
+ if (kernelSIMDSize != 0 )
228
+ {
229
+ if (IsSIMDSizeValid (kernelSIMDSize))
230
+ {
231
+ m_SIMDSize = kernelSIMDSize;
232
+ return ;
233
+ }
234
+ // if set on entry function level and not ok for this platform exit with error
235
+ std::string msg = " Sub group size " + std::to_string (kernelSIMDSize) + " is forced by attribute but not supported by Joint Matrix on this platform." ;
236
+ m_Ctx->EmitError (msg.c_str (), nullptr );
237
+ return ;
238
+ }
239
+ }
240
+ // if not set on entry function level, define ourselves
241
+ m_SIMDSize = DefineKernelSIMDSize ();
242
+ // and set to entry level function
243
+ subGroupSize->setSIMD_size (m_SIMDSize);
244
+ return ;
142
245
}
143
246
144
247
// If no entry function found (it means that we could not detect that current function is called
145
- // from any kernel) or kernel function doesn't have sub group size requirement set, we anyway
146
- // will resolve function, just in case, using default sub group size.
147
- // In theory it may cause resolution conflicts if sub group sizes are mixed.
148
- if (m_SIMDSize == 0 )
149
- m_SIMDSize = m_Ctx->platform .hasExecSize16DPAS () ? 16 : 8 ;
248
+ // from any kernel), we anyway will resolve function, just in case, using default sub group size.
249
+ m_SIMDSize = DefineKernelSIMDSize ();
250
+ // Force SIMD size if not set, as Joint Matrix need it to define numer of elements in WI
251
+ m_Ctx->getModuleMetaData ()->csInfo .forcedSIMDSize = (unsigned char )m_SIMDSize;
150
252
}
151
253
152
254
bool JointMatrixFuncsResolutionPass::runOnFunction (Function& F)
@@ -178,7 +280,8 @@ bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
178
280
return !ResolvedValues.empty ();
179
281
}
180
282
181
- static const char *CommonBIPrefix = " __builtin_spirv_" ;
283
+ static const char *JointMatrixBIPrefix = " __builtin_spirv_OpJointMatrix" ;
284
+ static const char *JointMatrixBISuffix = " JointMatrixINTEL_" ;
182
285
static const char *JointMatrixLoadPrefx = " JointMatrixLoadINTEL" ;
183
286
static const char *JointMatrixStorePrefx = " JointMatrixStoreINTEL" ;
184
287
static const char *JointMatrixMadPrefx = " JointMatrixMadINTEL" ;
@@ -1661,7 +1764,7 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
1661
1764
* future when returning and passing matrices by argument is
1662
1765
* supported also basic block terminators should be used as
1663
1766
* transformation starting point */
1664
- if (funcName.startswith (CommonBIPrefix )) {
1767
+ if (funcName.startswith (JointMatrixBIPrefix) || funcName. contains (JointMatrixBISuffix )) {
1665
1768
ResolveSIMDSize (CI.getParent ()->getParent ());
1666
1769
ResolveCall (&CI);
1667
1770
return ;
0 commit comments