Skip to content

Commit c8332b3

Browse files
YuriPlyakhinigcbot
authored andcommitted
Force SIMD Size for Joint Matrix code
Since Joint Matrix requires the SIMD size to be known in order to resolve types and calls, force the SIMD size to be set if it is not already set in the code.
1 parent b6bd485 commit c8332b3

File tree

4 files changed

+122
-13
lines changed

4 files changed

+122
-13
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 114 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,125 @@ Function *JointMatrixFuncsResolutionPass::getEntryFunction(Function *F)
130130
return nullptr;
131131
}
132132

133-
void JointMatrixFuncsResolutionPass::ResolveSIMDSize(Function *F) {
134-
if (m_SIMDSize != 0) return;
133+
int32_t JointMatrixFuncsResolutionPass::DetermineForcedSIMDSize()
134+
{
135+
int32_t forcedSIMDSize = m_Ctx->getModuleMetaData()->csInfo.forcedSIMDSize;
136+
137+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD32) && IGC_IS_FLAG_DISABLED(ForceCSSIMD16) && (forcedSIMDSize == 32 || IGC_IS_FLAG_ENABLED(ForceCSSIMD32)))
138+
{
139+
if (forcedSIMDSize == 0)
140+
m_Ctx->getModuleMetaData()->csInfo.forcedSIMDSize = 32;
141+
return 32;
142+
}
143+
144+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD16) && IGC_IS_FLAG_DISABLED(ForceCSSIMD32) && (forcedSIMDSize == 16 || IGC_IS_FLAG_ENABLED(ForceCSSIMD16)))
145+
{
146+
if (forcedSIMDSize == 0)
147+
m_Ctx->getModuleMetaData()->csInfo.forcedSIMDSize = 16;
148+
return 16;
149+
}
150+
151+
return forcedSIMDSize;
152+
}
153+
154+
int32_t JointMatrixFuncsResolutionPass::DefineKernelSIMDSize()
155+
{
156+
if (m_Ctx->platform.hasExecSize16DPAS())
157+
{
158+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD16) && IGC_IS_FLAG_DISABLED(ForceCSSIMD32))
159+
return 16;
160+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD32) && IGC_IS_FLAG_DISABLED(ForceCSSIMD16))
161+
return 32;
162+
std::string msg = "Sub group sizes supported by Joint Matrix for this platform are disabled by flags or non-supported sub group size forced.";
163+
m_Ctx->EmitError(msg.c_str(), nullptr);
164+
return 0;
165+
}
166+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD32) && IGC_IS_FLAG_ENABLED(ForceCSSIMD32))
167+
{
168+
std::string msg = "Sub group size 32 forced by flags but not supported by Joint Matrix on this platform.";
169+
m_Ctx->EmitError(msg.c_str(), nullptr);
170+
return 0;
171+
}
172+
if (IGC_IS_FLAG_ENABLED(EnableOCLSIMD16) && IGC_IS_FLAG_ENABLED(ForceCSSIMD16))
173+
{
174+
std::string msg = "Sub group size 16 forced by flags but not supported by Joint Matrix on this platform.";
175+
m_Ctx->EmitError(msg.c_str(), nullptr);
176+
return 0;
177+
}
178+
return 8;
179+
}
180+
181+
bool JointMatrixFuncsResolutionPass::IsSIMDSizeValid(int32_t simdSize)
182+
{
183+
return ((m_Ctx->platform.hasExecSize16DPAS() && (simdSize == 16 || simdSize == 32)) ||
184+
(!m_Ctx->platform.hasExecSize16DPAS() && simdSize == 8));
185+
}
135186

187+
// Writes forcedSIMDSize into the sub-group-size metadata of the entry function
// that getEntryFunction() finds for F. Does nothing when no entry function is
// found.
// Diff-view note: the line following the `137-` marker is pre-commit code that
// this change removes; it is replaced by the line following `190+`.
void JointMatrixFuncsResolutionPass::ForceKernelSIMDSize(Function *F, int32_t forcedSIMDSize)
188+
{
136189
Function *entryFunction = getEntryFunction(F);
137-
if (entryFunction) {
190+
if (entryFunction) // only when an entry function could be found
191+
{
192+
IGCMD::FunctionInfoMetaDataHandle funcInfoMD = m_mdUtils->getFunctionsInfoItem(entryFunction);
193+
IGCMD::SubGroupSizeMetaDataHandle subGroupSize = funcInfoMD->getSubGroupSize();
194+
subGroupSize->setSIMD_size(forcedSIMDSize); // record the forced size on the entry function
195+
}
196+
}
197+
198+
// Resolves m_SIMDSize for the function F being processed. No-op when
// m_SIMDSize is already non-zero. Resolution order:
//   1. A size forced via module metadata or flags (DetermineForcedSIMDSize()).
//      If valid for the platform, it is stored in m_SIMDSize and written to the
//      entry function's metadata; if invalid, an error is emitted and
//      m_SIMDSize is left unchanged.
//   2. A non-zero sub group size already set on the entry function. If valid it
//      is stored in m_SIMDSize; if invalid, an error is emitted.
//   3. Otherwise DefineKernelSIMDSize() picks the size, which is written to the
//      entry function's metadata, or to csInfo.forcedSIMDSize when no entry
//      function was found.
// NOTE(review): DefineKernelSIMDSize() returns 0 after emitting an error; in
// that case 0 is what gets written to the metadata in step 3.
// Diff-view note: lines following an `NNN-` marker are pre-commit code removed
// by this change.
void JointMatrixFuncsResolutionPass::ResolveSIMDSize(Function *F)
199+
{
200+
if (m_SIMDSize != 0)
201+
return;
202+
203+
int32_t forcedSIMDSize = DetermineForcedSIMDSize();
204+
if (forcedSIMDSize != 0)
205+
{
206+
if (IsSIMDSizeValid(forcedSIMDSize))
207+
{
208+
m_SIMDSize = forcedSIMDSize;
209+
ForceKernelSIMDSize(F, m_SIMDSize);
210+
return;
211+
}
212+
// If the size is forced but not valid for this platform, exit with an error.
213+
std::string msg = "Sub group size " + std::to_string(forcedSIMDSize) + " is forced by flags but not supported by Joint Matrix on this platform.";
214+
m_Ctx->EmitError(msg.c_str(), nullptr);
215+
return;
216+
}
217+
218+
// If not forced by the driver or by flags, check at the entry function level.
219+
Function *entryFunction = getEntryFunction(F);
220+
if (entryFunction) // an entry function could be found
221+
{
138222
IGCMD::FunctionInfoMetaDataHandle funcInfoMD = m_mdUtils->getFunctionsInfoItem(entryFunction);
139223
IGCMD::SubGroupSizeMetaDataHandle subGroupSize = funcInfoMD->getSubGroupSize();
140224
if (subGroupSize->hasValue())
141-
m_SIMDSize = subGroupSize->getSIMD_size();
225+
{
226+
int32_t kernelSIMDSize = subGroupSize->getSIMD_size();
227+
if (kernelSIMDSize != 0)
228+
{
229+
if (IsSIMDSizeValid(kernelSIMDSize))
230+
{
231+
m_SIMDSize = kernelSIMDSize;
232+
return;
233+
}
234+
// If the size is set at the entry function level but is not valid for this platform, exit with an error.
235+
std::string msg = "Sub group size " + std::to_string(kernelSIMDSize) + " is forced by attribute but not supported by Joint Matrix on this platform.";
236+
m_Ctx->EmitError(msg.c_str(), nullptr);
237+
return;
238+
}
239+
}
240+
// If the size is not set at the entry function level, choose it ourselves...
241+
m_SIMDSize = DefineKernelSIMDSize();
242+
// ...and record it on the entry function.
243+
subGroupSize->setSIMD_size(m_SIMDSize);
244+
return;
142245
}
143246

144247
// If no entry function found (it means that we could not detect that current function is called
145-
// from any kernel) or kernel function doesn't have sub group size requirement set, we anyway
146-
// will resolve function, just in case, using default sub group size.
147-
// In theory it may cause resolution conflicts if sub group sizes are mixed.
148-
if (m_SIMDSize == 0)
149-
m_SIMDSize = m_Ctx->platform.hasExecSize16DPAS() ? 16 : 8;
248+
// from any kernel), we resolve the function anyway, just in case, using the default sub group size.
249+
m_SIMDSize = DefineKernelSIMDSize();
250+
// Force the SIMD size if it is not set, as Joint Matrix needs it to define the number of elements per work item (WI).
251+
m_Ctx->getModuleMetaData()->csInfo.forcedSIMDSize = (unsigned char)m_SIMDSize;
150252
}
151253

152254
bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
@@ -178,7 +280,8 @@ bool JointMatrixFuncsResolutionPass::runOnFunction(Function& F)
178280
return !ResolvedValues.empty();
179281
}
180282

181-
static const char *CommonBIPrefix = "__builtin_spirv_";
283+
// Call-name prefix tested with funcName.startswith() in visitCallInst to select calls for SIMD-size resolution and ResolveCall.
static const char *JointMatrixBIPrefix = "__builtin_spirv_OpJointMatrix";
284+
// Call-name substring tested with funcName.contains() in visitCallInst, as an alternative to the prefix match.
static const char *JointMatrixBISuffix = "JointMatrixINTEL_";
182285
static const char *JointMatrixLoadPrefx = "JointMatrixLoadINTEL";
183286
static const char *JointMatrixStorePrefx = "JointMatrixStoreINTEL";
184287
static const char *JointMatrixMadPrefx = "JointMatrixMadINTEL";
@@ -1661,7 +1764,7 @@ void JointMatrixFuncsResolutionPass::visitCallInst(CallInst& CI)
16611764
* future when returning and passing matrices by argument is
16621765
* supported also basic block terminators should be used as
16631766
* transformation starting point */
1664-
if (funcName.startswith(CommonBIPrefix)) {
1767+
if (funcName.startswith(JointMatrixBIPrefix) || funcName.contains(JointMatrixBISuffix)) {
16651768
ResolveSIMDSize(CI.getParent()->getParent());
16661769
ResolveCall(&CI);
16671770
return;

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,14 @@ namespace IGC
8787

8888
bool ValidateLoadStore
8989
(bool isLoad, unsigned operationLayout, const JointMatrixTypeDescription *desc, llvm::Value *ctx);
90+
91+
// SIMD Size helpers
9092
llvm::Function *getEntryFunction(llvm::Function *F); // entry function found for F; may be null (callers check the result)
9193
void ResolveSIMDSize(llvm::Function *F); // sets m_SIMDSize; no-op if it is already non-zero
94+
int32_t DetermineForcedSIMDSize(); // size forced via module metadata / ForceCSSIMD16/32 flags, else the metadata value as-is
95+
int32_t DefineKernelSIMDSize(); // picks a size from platform DPAS support and flags; returns 0 after emitting an error
96+
bool IsSIMDSizeValid(int32_t simdSize); // 16 or 32 when platform.hasExecSize16DPAS(), otherwise only 8
97+
void ForceKernelSIMDSize(llvm::Function *F, int32_t forcedSIMDSize); // writes the size to the entry function's sub-group-size metadata, if an entry function is found
9298

9399
llvm::ValueMap<llvm::Value *, llvm::Instruction *> PlaceholderInstructions;
94100
llvm::ValueMap<llvm::Value *, llvm::Value *> ResolvedValues;

IGC/Compiler/tests/JointMatrixFuncsResolutionPass/address-spaces.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,4 +211,4 @@ declare void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x6
211211
!5 = !{void (i8 addrspace(3)*, i8 addrspace(3)*)* @test_local, !1}
212212
!1 = !{!2, !3}
213213
!2 = !{!"function_type", i32 0}
214-
!3 = !{!"sub_group_size", i32 16}
214+
!3 = !{!"sub_group_size", i32 8}

IGC/Compiler/tests/JointMatrixFuncsResolutionPass/basic.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,4 @@ declare spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16(i8*, %intel
7979
!0 = !{void (i32, i8*, i32*, i8*, i8*)* @test_jm, !1}
8080
!1 = !{!2, !3}
8181
!2 = !{!"function_type", i32 0}
82-
!3 = !{!"sub_group_size", i32 16}
82+
!3 = !{!"sub_group_size", i32 8}

0 commit comments

Comments
 (0)