@@ -215,6 +215,7 @@ class GenXLowering : public FunctionPass {
215
215
bool runOnFunction (Function &F);
216
216
217
217
private:
218
+ bool translateSLMOWord (CallInst* CI, unsigned IID);
218
219
bool splitGatherScatter (CallInst *CI, unsigned IID);
219
220
bool processTwoAddressOpnd (CallInst *CI);
220
221
bool processInst (Instruction *Inst);
@@ -1059,6 +1060,218 @@ bool GenXLowering::splitGatherScatter(CallInst *CI, unsigned IID) {
1059
1060
return true ;
1060
1061
}
1061
1062
1063
+ static Constant* getConstVector (Type* ITy, unsigned int n, unsigned int step) {
1064
+ std::vector<Constant*> vConsts;
1065
+ unsigned v = 0 ;
1066
+ for (unsigned i = 0 ; i < n; i++) {
1067
+ vConsts.push_back (ConstantInt::get (ITy, v));
1068
+ v += step;
1069
+ }
1070
+ return ConstantVector::get (vConsts);
1071
+ }
1072
+
1073
+
1074
+ /* **********************************************************************
1075
+ * translateSLMOWord : lower SLM OWord load/store to gathers/scatters
1076
+ * on legacy platform such as SKL.
1077
+ *
1078
+ * We only support the cases of 1,2,4,8 oword cases
1079
+ */
1080
+ bool GenXLowering::translateSLMOWord (CallInst* CI, unsigned IID) {
1081
+ LLVMContext& CTX = CI->getContext ();
1082
+ auto CIntTy = IntegerType::getInt32Ty (CTX);
1083
+ const DebugLoc& DL = CI->getDebugLoc ();
1084
+ switch (IID) {
1085
+ case GenXIntrinsic::genx_oword_ld:
1086
+ case GenXIntrinsic::genx_oword_ld_unaligned: {
1087
+ constexpr unsigned BtiIdx = 1 ;
1088
+ constexpr unsigned AddrIdx = 2 ;
1089
+ Value* BtiV = CI->getArgOperand (BtiIdx);
1090
+ // only slm need this lowering
1091
+ if (!isa<ConstantInt>(BtiV))
1092
+ return false ;
1093
+ if (cast<ConstantInt>(BtiV)->getZExtValue () !=
1094
+ visa::ReservedSurfaceIndex::RSI_Slm)
1095
+ return false ;
1096
+
1097
+ IRBuilder<> Builder (CI);
1098
+ Value* AddrV = CI->getArgOperand (AddrIdx);
1099
+ if (IID == GenXIntrinsic::genx_oword_ld) {
1100
+ AddrV = Builder.CreateShl (AddrV, llvm::ConstantInt::get (AddrV->getType (), 4 ));
1101
+ }
1102
+ auto OrigVT = cast<VectorType>(CI->getType ());
1103
+ unsigned EltSize = OrigVT->getScalarSizeInBits ();
1104
+ unsigned EltCount = OrigVT->getNumElements ();
1105
+ // 1-oword is 16 bytes, using simd4 dword gather-scaled
1106
+ // 2-oword is 32 bytes, using simd8 dword gather-scaled
1107
+ // 4-oword is 64 bytes, using simd16 dword gather-scaled
1108
+ // 8-oword is 128 bytes, using 2*simd16 dword gather-scaled
1109
+ unsigned DWordCnt = (EltSize * EltCount) / 32 ;
1110
+ assert (DWordCnt == 4 || DWordCnt == 8 || DWordCnt == 16 || DWordCnt == 32 );
1111
+ unsigned SimdWidth = (DWordCnt == 32 ) ? 16 : DWordCnt;
1112
+ auto NewVT = IGCLLVM::FixedVectorType::get (CIntTy, DWordCnt);
1113
+ auto GatherVT = IGCLLVM::FixedVectorType::get (CIntTy, SimdWidth);
1114
+ // generate gather-scaled
1115
+ auto VOffset = getConstVector (CIntTy, SimdWidth, 4 );
1116
+ // create constant for predicate
1117
+ auto PredVTy = IGCLLVM::FixedVectorType::get (IntegerType::getInt1Ty (CTX), SimdWidth);
1118
+ auto OnePredV = Constant::getAllOnesValue (PredVTy);
1119
+ auto ScaleC = ConstantInt::get (Type::getInt16Ty (CTX), 0 );
1120
+ std::string IntrName =
1121
+ std::string (GenXIntrinsic::getGenXIntrinsicPrefix ()) + " gather.scaled" ;
1122
+ auto ID = GenXIntrinsic::lookupGenXIntrinsicID (IntrName);
1123
+ // crease constant for num-blocks, 2 means 4-bytes
1124
+ auto NumBlksC = ConstantInt::get (CIntTy, 2 );
1125
+ // create the intrinsic call
1126
+ Function* NewFDecl = GenXIntrinsic::getGenXDeclaration (
1127
+ CI->getModule (), ID, { GatherVT, PredVTy, VOffset->getType () });
1128
+ Instruction* NewInst = nullptr ;
1129
+ if (DWordCnt == SimdWidth) {
1130
+ NewInst = IntrinsicInst::Create (NewFDecl,
1131
+ { OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1132
+ VOffset, UndefValue::get (GatherVT) },
1133
+ CI->getName () + " .gather" , CI);
1134
+ NewInst->setDebugLoc (DL);
1135
+ LLVM_DEBUG (dbgs () << " SLM OWord Load:\n " );
1136
+ LLVM_DEBUG (CI->dump ());
1137
+ LLVM_DEBUG (dbgs () << " Translated to gather:\n " );
1138
+ LLVM_DEBUG (NewInst->dump ());
1139
+ }
1140
+ else { // need to two gathers for 8 owords
1141
+ // 1st gather
1142
+ auto New1st = IntrinsicInst::Create (NewFDecl,
1143
+ { OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1144
+ VOffset, UndefValue::get (GatherVT) },
1145
+ CI->getName () + " .gather1" , CI);
1146
+ New1st->setDebugLoc (DL);
1147
+ // 2nd gather
1148
+ AddrV = Builder.CreateAdd (AddrV, llvm::ConstantInt::get (AddrV->getType (), 64 ));
1149
+ auto New2nd = IntrinsicInst::Create (NewFDecl,
1150
+ { OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1151
+ VOffset, UndefValue::get (GatherVT) },
1152
+ CI->getName () + " .gather2" , CI);
1153
+ New2nd->setDebugLoc (DL);
1154
+ // write region, 1st half
1155
+ Region R (NewVT);
1156
+ R.Width = SimdWidth;
1157
+ R.NumElements = SimdWidth;
1158
+ R.Stride = 1 ;
1159
+ R.VStride = 0 ;
1160
+ R.Offset = 0 ;
1161
+ auto PartialV = R.createWrRegion (UndefValue::get (NewVT), New1st, " " , CI, CI->getDebugLoc ());
1162
+ // write region, 2nd half
1163
+ R.Offset = 64 ;
1164
+ NewInst = R.createWrRegion (PartialV, New2nd, " " , CI, CI->getDebugLoc ());
1165
+ LLVM_DEBUG (dbgs () << " SLM OWord Load:\n " );
1166
+ LLVM_DEBUG (CI->dump ());
1167
+ LLVM_DEBUG (dbgs () << " Translated to gather:\n " );
1168
+ LLVM_DEBUG (New1st->dump ());
1169
+ LLVM_DEBUG (New2nd->dump ());
1170
+ }
1171
+ // cast back if required
1172
+ Value* Casted = NewInst;
1173
+ if (NewVT != OrigVT)
1174
+ Casted = CastInst::CreateBitOrPointerCast (
1175
+ Casted, OrigVT, Casted->getName () + VALUE_NAME (" .cast" ), CI);
1176
+ CI->replaceAllUsesWith (Casted);
1177
+ ToErase.push_back (CI);
1178
+ return true ;
1179
+ }
1180
+ case GenXIntrinsic::genx_oword_st: {
1181
+ constexpr unsigned DataIdx = 2 ;
1182
+ constexpr unsigned AddrIdx = 1 ;
1183
+ constexpr unsigned BtiIdx = 0 ;
1184
+ Value* BtiV = CI->getArgOperand (BtiIdx);
1185
+ // Only slm need this lowering
1186
+ if (!isa<ConstantInt>(BtiV))
1187
+ return false ;
1188
+ if (cast<ConstantInt>(BtiV)->getZExtValue () !=
1189
+ visa::ReservedSurfaceIndex::RSI_Slm)
1190
+ return false ;
1191
+
1192
+ IRBuilder<> Builder (CI);
1193
+ Value* AddrV = CI->getArgOperand (AddrIdx);
1194
+ AddrV = Builder.CreateShl (AddrV, llvm::ConstantInt::get (AddrV->getType (), 4 ));
1195
+
1196
+ Value* Datum = CI->getArgOperand (DataIdx);
1197
+ auto OrigVT = cast<VectorType>(Datum->getType ());
1198
+ unsigned EltSize = OrigVT->getScalarSizeInBits ();
1199
+ unsigned EltCount = OrigVT->getNumElements ();
1200
+ // 1-oword is 16 bytes, using simd4 dword scatter-scaled
1201
+ // 2-oword is 32 bytes, using simd8 dword scatter-scaled
1202
+ // 4-oword is 64 bytes, using simd16 dword scatter-scaled
1203
+ // 8-oword is 128 bytes, using 2*simd16 dword scatter-scaled
1204
+ unsigned DWordCnt = (EltSize * EltCount) / 32 ;
1205
+ assert (DWordCnt == 4 || DWordCnt == 8 || DWordCnt == 16 || DWordCnt == 32 );
1206
+ auto NewVT = IGCLLVM::FixedVectorType::get (CIntTy, DWordCnt);
1207
+ IGC_ASSERT_MESSAGE (CastInst::isBitCastable (NewVT, OrigVT),
1208
+ " We expect resulting vectors to be bitcastable" );
1209
+ if (NewVT != OrigVT)
1210
+ Datum = CastInst::CreateBitOrPointerCast (
1211
+ Datum, NewVT, Datum->getName () + VALUE_NAME (" .cast" ), CI);
1212
+ unsigned SimdWidth = (DWordCnt == 32 ) ? 16 : DWordCnt;
1213
+
1214
+ // generate scatter-scaled
1215
+ auto VOffset = getConstVector (CIntTy, SimdWidth, 4 );
1216
+ // create constant for predicate
1217
+ auto PredVTy = IGCLLVM::FixedVectorType::get (IntegerType::getInt1Ty (CTX), SimdWidth);
1218
+ auto OnePredV = Constant::getAllOnesValue (PredVTy);
1219
+ auto ScaleC = ConstantInt::get (Type::getInt16Ty (CTX), 0 );
1220
+ // create constant for num-blocks
1221
+ auto NumBlksC = ConstantInt::get (CIntTy, 2 );
1222
+ std::string IntrName =
1223
+ std::string (GenXIntrinsic::getGenXIntrinsicPrefix ()) + " scatter.scaled" ;
1224
+ auto ID = GenXIntrinsic::lookupGenXIntrinsicID (IntrName);
1225
+ // create the intrinsic call
1226
+ auto ScatterVT = IGCLLVM::FixedVectorType::get (CIntTy, SimdWidth);
1227
+ Function* NewFDecl = GenXIntrinsic::getGenXDeclaration (
1228
+ CI->getModule (), ID, { PredVTy, VOffset->getType (), ScatterVT });
1229
+ if (DWordCnt == SimdWidth) {
1230
+ // create one scatter
1231
+ auto NewInst = Builder.CreateCall (
1232
+ NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum });
1233
+ NewInst->setDebugLoc (DL);
1234
+ LLVM_DEBUG (dbgs () << " SLM OWord Store:\n " );
1235
+ LLVM_DEBUG (CI->dump ());
1236
+ LLVM_DEBUG (dbgs () << " Translated to scatter:\n " );
1237
+ LLVM_DEBUG (NewInst->dump ());
1238
+ }
1239
+ else { // 8-oword (i.e 32 dword) case
1240
+ // scatter the 1st 16 dwords
1241
+ // read region then scatter
1242
+ Region R (ScatterVT);
1243
+ R.Width = SimdWidth;
1244
+ R.NumElements = SimdWidth;
1245
+ R.Stride = 1 ;
1246
+ R.VStride = 0 ;
1247
+ R.Offset = 0 ;
1248
+ auto Datum1st = R.createRdRegion (Datum, " " , CI, CI->getDebugLoc ());
1249
+ auto New1st = Builder.CreateCall (
1250
+ NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum1st });
1251
+ New1st->setDebugLoc (DL);
1252
+ // scatter the 2nd 16 dwords
1253
+ // read region then scatter
1254
+ AddrV = Builder.CreateAdd (AddrV, llvm::ConstantInt::get (AddrV->getType (), 64 ));
1255
+ R.Offset = 64 ;
1256
+ auto Datum2nd = R.createRdRegion (Datum, " " , CI, CI->getDebugLoc ());
1257
+ auto New2nd = Builder.CreateCall (
1258
+ NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum2nd });
1259
+ New2nd->setDebugLoc (DL);
1260
+ LLVM_DEBUG (dbgs () << " SLM OWord Store:\n " );
1261
+ LLVM_DEBUG (CI->dump ());
1262
+ LLVM_DEBUG (dbgs () << " Translated to scatter:\n " );
1263
+ LLVM_DEBUG (New1st->dump ());
1264
+ LLVM_DEBUG (New2nd->dump ());
1265
+ }
1266
+ ToErase.push_back (CI);
1267
+ return true ;
1268
+ }
1269
+ default :
1270
+ break ;
1271
+ }
1272
+ return false ;
1273
+ }
1274
+
1062
1275
1063
1276
/* **********************************************************************
1064
1277
* generatePrecicatedWrrForNewLoad : Generate predicated wrr if result
@@ -1133,6 +1346,14 @@ bool GenXLowering::processInst(Instruction *Inst) {
1133
1346
if (Function *Callee = CI->getCalledFunction ()) {
1134
1347
IntrinsicID = GenXIntrinsic::getAnyIntrinsicID (Callee);
1135
1348
IGC_ASSERT (CI->getNumArgOperands () < GenXIntrinsicInfo::OPNDMASK);
1349
+ }
1350
+ if (ST) {
1351
+ // use gather/scatter to implement SLM oword load/store on
1352
+ // legacy platforms
1353
+ if (!ST->isCNLplus ()) {
1354
+ if (translateSLMOWord (CI, IntrinsicID))
1355
+ return true ;
1356
+ }
1136
1357
}
1137
1358
// split gather/scatter/atomic into the width legal to the target
1138
1359
if (splitGatherScatter (CI, IntrinsicID))
0 commit comments