Skip to content

Commit 332f230

Browse files
Gang Y Chenigcbot
authored andcommitted
Need to check Bti Value is constant first
1 parent 364808d commit 332f230

File tree

1 file changed

+221
-0
lines changed

1 file changed

+221
-0
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXLowering.cpp

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ class GenXLowering : public FunctionPass {
215215
bool runOnFunction(Function &F);
216216

217217
private:
218+
bool translateSLMOWord(CallInst* CI, unsigned IID);
218219
bool splitGatherScatter(CallInst *CI, unsigned IID);
219220
bool processTwoAddressOpnd(CallInst *CI);
220221
bool processInst(Instruction *Inst);
@@ -1059,6 +1060,218 @@ bool GenXLowering::splitGatherScatter(CallInst *CI, unsigned IID) {
10591060
return true;
10601061
}
10611062

1063+
static Constant* getConstVector(Type* ITy, unsigned int n, unsigned int step) {
1064+
std::vector<Constant*> vConsts;
1065+
unsigned v = 0;
1066+
for (unsigned i = 0; i < n; i++) {
1067+
vConsts.push_back(ConstantInt::get(ITy, v));
1068+
v += step;
1069+
}
1070+
return ConstantVector::get(vConsts);
1071+
}
1072+
1073+
1074+
/***********************************************************************
1075+
* translateSLMOWord : lower SLM OWord load/store to gathers/scatters
1076+
* on legacy platform such as SKL.
1077+
*
1078+
* We only support the cases of 1,2,4,8 oword cases
1079+
*/
1080+
bool GenXLowering::translateSLMOWord(CallInst* CI, unsigned IID) {
1081+
LLVMContext& CTX = CI->getContext();
1082+
auto CIntTy = IntegerType::getInt32Ty(CTX);
1083+
const DebugLoc& DL = CI->getDebugLoc();
1084+
switch (IID) {
1085+
case GenXIntrinsic::genx_oword_ld:
1086+
case GenXIntrinsic::genx_oword_ld_unaligned: {
1087+
constexpr unsigned BtiIdx = 1;
1088+
constexpr unsigned AddrIdx = 2;
1089+
Value* BtiV = CI->getArgOperand(BtiIdx);
1090+
// only slm need this lowering
1091+
if (!isa<ConstantInt>(BtiV))
1092+
return false;
1093+
if (cast<ConstantInt>(BtiV)->getZExtValue() !=
1094+
visa::ReservedSurfaceIndex::RSI_Slm)
1095+
return false;
1096+
1097+
IRBuilder<> Builder(CI);
1098+
Value* AddrV = CI->getArgOperand(AddrIdx);
1099+
if (IID == GenXIntrinsic::genx_oword_ld) {
1100+
AddrV = Builder.CreateShl(AddrV, llvm::ConstantInt::get(AddrV->getType(), 4));
1101+
}
1102+
auto OrigVT = cast<VectorType>(CI->getType());
1103+
unsigned EltSize = OrigVT->getScalarSizeInBits();
1104+
unsigned EltCount = OrigVT->getNumElements();
1105+
// 1-oword is 16 bytes, using simd4 dword gather-scaled
1106+
// 2-oword is 32 bytes, using simd8 dword gather-scaled
1107+
// 4-oword is 64 bytes, using simd16 dword gather-scaled
1108+
// 8-oword is 128 bytes, using 2*simd16 dword gather-scaled
1109+
unsigned DWordCnt = (EltSize * EltCount) / 32;
1110+
assert(DWordCnt == 4 || DWordCnt == 8 || DWordCnt == 16 || DWordCnt == 32);
1111+
unsigned SimdWidth = (DWordCnt == 32) ? 16 : DWordCnt;
1112+
auto NewVT = IGCLLVM::FixedVectorType::get(CIntTy, DWordCnt);
1113+
auto GatherVT = IGCLLVM::FixedVectorType::get(CIntTy, SimdWidth);
1114+
// generate gather-scaled
1115+
auto VOffset = getConstVector(CIntTy, SimdWidth, 4);
1116+
// create constant for predicate
1117+
auto PredVTy = IGCLLVM::FixedVectorType::get(IntegerType::getInt1Ty(CTX), SimdWidth);
1118+
auto OnePredV = Constant::getAllOnesValue(PredVTy);
1119+
auto ScaleC = ConstantInt::get(Type::getInt16Ty(CTX), 0);
1120+
std::string IntrName =
1121+
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + "gather.scaled";
1122+
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
1123+
// crease constant for num-blocks, 2 means 4-bytes
1124+
auto NumBlksC = ConstantInt::get(CIntTy, 2);
1125+
// create the intrinsic call
1126+
Function* NewFDecl = GenXIntrinsic::getGenXDeclaration(
1127+
CI->getModule(), ID, { GatherVT, PredVTy, VOffset->getType() });
1128+
Instruction* NewInst = nullptr;
1129+
if (DWordCnt == SimdWidth) {
1130+
NewInst = IntrinsicInst::Create(NewFDecl,
1131+
{ OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1132+
VOffset, UndefValue::get(GatherVT) },
1133+
CI->getName() + ".gather", CI);
1134+
NewInst->setDebugLoc(DL);
1135+
LLVM_DEBUG(dbgs() << "SLM OWord Load:\n");
1136+
LLVM_DEBUG(CI->dump());
1137+
LLVM_DEBUG(dbgs() << "Translated to gather:\n");
1138+
LLVM_DEBUG(NewInst->dump());
1139+
}
1140+
else { // need to two gathers for 8 owords
1141+
// 1st gather
1142+
auto New1st = IntrinsicInst::Create(NewFDecl,
1143+
{ OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1144+
VOffset, UndefValue::get(GatherVT) },
1145+
CI->getName() + ".gather1", CI);
1146+
New1st->setDebugLoc(DL);
1147+
// 2nd gather
1148+
AddrV = Builder.CreateAdd(AddrV, llvm::ConstantInt::get(AddrV->getType(), 64));
1149+
auto New2nd = IntrinsicInst::Create(NewFDecl,
1150+
{ OnePredV, NumBlksC, ScaleC, BtiV, AddrV,
1151+
VOffset, UndefValue::get(GatherVT) },
1152+
CI->getName() + ".gather2", CI);
1153+
New2nd->setDebugLoc(DL);
1154+
// write region, 1st half
1155+
Region R(NewVT);
1156+
R.Width = SimdWidth;
1157+
R.NumElements = SimdWidth;
1158+
R.Stride = 1;
1159+
R.VStride = 0;
1160+
R.Offset = 0;
1161+
auto PartialV = R.createWrRegion(UndefValue::get(NewVT), New1st, "", CI, CI->getDebugLoc());
1162+
// write region, 2nd half
1163+
R.Offset = 64;
1164+
NewInst = R.createWrRegion(PartialV, New2nd, "", CI, CI->getDebugLoc());
1165+
LLVM_DEBUG(dbgs() << "SLM OWord Load:\n");
1166+
LLVM_DEBUG(CI->dump());
1167+
LLVM_DEBUG(dbgs() << "Translated to gather:\n");
1168+
LLVM_DEBUG(New1st->dump());
1169+
LLVM_DEBUG(New2nd->dump());
1170+
}
1171+
// cast back if required
1172+
Value* Casted = NewInst;
1173+
if (NewVT != OrigVT)
1174+
Casted = CastInst::CreateBitOrPointerCast(
1175+
Casted, OrigVT, Casted->getName() + VALUE_NAME(".cast"), CI);
1176+
CI->replaceAllUsesWith(Casted);
1177+
ToErase.push_back(CI);
1178+
return true;
1179+
}
1180+
case GenXIntrinsic::genx_oword_st: {
1181+
constexpr unsigned DataIdx = 2;
1182+
constexpr unsigned AddrIdx = 1;
1183+
constexpr unsigned BtiIdx = 0;
1184+
Value* BtiV = CI->getArgOperand(BtiIdx);
1185+
// Only slm need this lowering
1186+
if (!isa<ConstantInt>(BtiV))
1187+
return false;
1188+
if (cast<ConstantInt>(BtiV)->getZExtValue() !=
1189+
visa::ReservedSurfaceIndex::RSI_Slm)
1190+
return false;
1191+
1192+
IRBuilder<> Builder(CI);
1193+
Value* AddrV = CI->getArgOperand(AddrIdx);
1194+
AddrV = Builder.CreateShl(AddrV, llvm::ConstantInt::get(AddrV->getType(), 4));
1195+
1196+
Value* Datum = CI->getArgOperand(DataIdx);
1197+
auto OrigVT = cast<VectorType>(Datum->getType());
1198+
unsigned EltSize = OrigVT->getScalarSizeInBits();
1199+
unsigned EltCount = OrigVT->getNumElements();
1200+
// 1-oword is 16 bytes, using simd4 dword scatter-scaled
1201+
// 2-oword is 32 bytes, using simd8 dword scatter-scaled
1202+
// 4-oword is 64 bytes, using simd16 dword scatter-scaled
1203+
// 8-oword is 128 bytes, using 2*simd16 dword scatter-scaled
1204+
unsigned DWordCnt = (EltSize * EltCount) / 32;
1205+
assert(DWordCnt == 4 || DWordCnt == 8 || DWordCnt == 16 || DWordCnt == 32);
1206+
auto NewVT = IGCLLVM::FixedVectorType::get(CIntTy, DWordCnt);
1207+
IGC_ASSERT_MESSAGE(CastInst::isBitCastable(NewVT, OrigVT),
1208+
"We expect resulting vectors to be bitcastable");
1209+
if (NewVT != OrigVT)
1210+
Datum = CastInst::CreateBitOrPointerCast(
1211+
Datum, NewVT, Datum->getName() + VALUE_NAME(".cast"), CI);
1212+
unsigned SimdWidth = (DWordCnt == 32) ? 16 : DWordCnt;
1213+
1214+
// generate scatter-scaled
1215+
auto VOffset = getConstVector(CIntTy, SimdWidth, 4);
1216+
// create constant for predicate
1217+
auto PredVTy = IGCLLVM::FixedVectorType::get(IntegerType::getInt1Ty(CTX), SimdWidth);
1218+
auto OnePredV = Constant::getAllOnesValue(PredVTy);
1219+
auto ScaleC = ConstantInt::get(Type::getInt16Ty(CTX), 0);
1220+
// create constant for num-blocks
1221+
auto NumBlksC = ConstantInt::get(CIntTy, 2);
1222+
std::string IntrName =
1223+
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + "scatter.scaled";
1224+
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
1225+
// create the intrinsic call
1226+
auto ScatterVT = IGCLLVM::FixedVectorType::get(CIntTy, SimdWidth);
1227+
Function* NewFDecl = GenXIntrinsic::getGenXDeclaration(
1228+
CI->getModule(), ID, { PredVTy, VOffset->getType(), ScatterVT });
1229+
if (DWordCnt == SimdWidth) {
1230+
// create one scatter
1231+
auto NewInst = Builder.CreateCall(
1232+
NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum });
1233+
NewInst->setDebugLoc(DL);
1234+
LLVM_DEBUG(dbgs() << "SLM OWord Store:\n");
1235+
LLVM_DEBUG(CI->dump());
1236+
LLVM_DEBUG(dbgs() << "Translated to scatter:\n");
1237+
LLVM_DEBUG(NewInst->dump());
1238+
}
1239+
else { // 8-oword (i.e 32 dword) case
1240+
// scatter the 1st 16 dwords
1241+
// read region then scatter
1242+
Region R(ScatterVT);
1243+
R.Width = SimdWidth;
1244+
R.NumElements = SimdWidth;
1245+
R.Stride = 1;
1246+
R.VStride = 0;
1247+
R.Offset = 0;
1248+
auto Datum1st = R.createRdRegion(Datum, "", CI, CI->getDebugLoc());
1249+
auto New1st = Builder.CreateCall(
1250+
NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum1st });
1251+
New1st->setDebugLoc(DL);
1252+
// scatter the 2nd 16 dwords
1253+
// read region then scatter
1254+
AddrV = Builder.CreateAdd(AddrV, llvm::ConstantInt::get(AddrV->getType(), 64));
1255+
R.Offset = 64;
1256+
auto Datum2nd = R.createRdRegion(Datum, "", CI, CI->getDebugLoc());
1257+
auto New2nd = Builder.CreateCall(
1258+
NewFDecl, { OnePredV, NumBlksC, ScaleC, BtiV, AddrV, VOffset, Datum2nd });
1259+
New2nd->setDebugLoc(DL);
1260+
LLVM_DEBUG(dbgs() << "SLM OWord Store:\n");
1261+
LLVM_DEBUG(CI->dump());
1262+
LLVM_DEBUG(dbgs() << "Translated to scatter:\n");
1263+
LLVM_DEBUG(New1st->dump());
1264+
LLVM_DEBUG(New2nd->dump());
1265+
}
1266+
ToErase.push_back(CI);
1267+
return true;
1268+
}
1269+
default:
1270+
break;
1271+
}
1272+
return false;
1273+
}
1274+
10621275

10631276
/***********************************************************************
10641277
* generatePrecicatedWrrForNewLoad : Generate predicated wrr if result
@@ -1133,6 +1346,14 @@ bool GenXLowering::processInst(Instruction *Inst) {
11331346
if (Function *Callee = CI->getCalledFunction()) {
11341347
IntrinsicID = GenXIntrinsic::getAnyIntrinsicID(Callee);
11351348
IGC_ASSERT(CI->getNumArgOperands() < GenXIntrinsicInfo::OPNDMASK);
1349+
}
1350+
if (ST) {
1351+
// use gather/scatter to implement SLM oword load/store on
1352+
// legacy platforms
1353+
if (!ST->isCNLplus()) {
1354+
if (translateSLMOWord(CI, IntrinsicID))
1355+
return true;
1356+
}
11361357
}
11371358
// split gather/scatter/atomic into the width legal to the target
11381359
if (splitGatherScatter(CI, IntrinsicID))

0 commit comments

Comments
 (0)