@@ -535,6 +535,14 @@ void CMSimdCFLower::processFunction(Function *ArgF)
535
535
unsigned CMWidth = PredicatedSubroutines[F];
536
536
// Find the simd branches.
537
537
bool FoundSIMD = findSimdBranches (CMWidth);
538
+
539
+ // Create shuffle mask for EM adjustment
540
+ if (ShuffleMask.empty ()) {
541
+ auto I32Ty = Type::getInt32Ty (F->getContext ());
542
+ for (unsigned i = 0 ; i != 32 ; ++i)
543
+ ShuffleMask.push_back (ConstantInt::get (I32Ty, i));
544
+ }
545
+
538
546
if (CMWidth > 0 || FoundSIMD) {
539
547
// Determine which basic blocks need to be predicated.
540
548
determinePredicatedBlocks ();
@@ -555,10 +563,13 @@ void CMSimdCFLower::processFunction(Function *ArgF)
555
563
lowerSimdCF ();
556
564
lowerUnmaskOps ();
557
565
}
566
+
567
+ ShuffleMask.clear ();
558
568
SimdBranches.clear ();
559
569
PredicatedBlocks.clear ();
560
570
JoinPoints.clear ();
561
571
RMAddrs.clear ();
572
+ OriginalPred.clear ();
562
573
AlreadyPredicated.clear ();
563
574
}
564
575
@@ -1214,6 +1225,7 @@ unsigned CMSimdCFLower::deduceNumChannels(Instruction *SI) {
1214
1225
// If it's not a function call then check for a specific instruction
1215
1226
unsigned IID = GenXIntrinsic::getGenXIntrinsicID (CI);
1216
1227
switch (IID) {
1228
+ case GenXIntrinsic::genx_gather4_masked_scaled2:
1217
1229
case GenXIntrinsic::genx_gather4_scaled2: {
1218
1230
unsigned AddrElems = VCINTR::VectorType::getNumElements (
1219
1231
cast<VectorType>(CI->getOperand (4 )->getType ()));
@@ -1262,6 +1274,7 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
1262
1274
CallInst *WrRegionToPredicate = nullptr ;
1263
1275
Use *U = &SI->getOperandUse (0 );
1264
1276
Use *UseNeedsUpdate = nullptr ;
1277
+ Value *ExistingPred = nullptr ;
1265
1278
for (;;) {
1266
1279
if (auto BC = dyn_cast<BitCastInst>(V)) {
1267
1280
U = &BC->getOperandUse (0 );
@@ -1277,6 +1290,15 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
1277
1290
unsigned IID = GenXIntrinsic::getGenXIntrinsicID (WrRegion);
1278
1291
if (IID != GenXIntrinsic::genx_wrregioni
1279
1292
&& IID != GenXIntrinsic::genx_wrregionf) {
1293
+ // genx_gather4_masked_scaled2 is slightly different: it has predicate
1294
+ // operand and its users have to be predicated as well since it returns value
1295
+ // with a size greater than the execution size
1296
+ if (IID == GenXIntrinsic::genx_gather4_masked_scaled2) {
1297
+ assert (AlreadyPredicated.find (WrRegion) != AlreadyPredicated.end ());
1298
+ if (OriginalPred.count (WrRegion))
1299
+ ExistingPred = OriginalPred[WrRegion];
1300
+ break ;
1301
+ }
1280
1302
// Not wrregion. See if it is an intrinsic that has already been
1281
1303
// predicated; if so do not attempt to predicate the store.
1282
1304
if (AlreadyPredicated.find (WrRegion) != AlreadyPredicated.end ())
@@ -1361,7 +1383,19 @@ void CMSimdCFLower::predicateStore(Instruction *SI, unsigned SimdWidth)
1361
1383
Load = CallInst::Create (Fn, Addr, " .simdcfpred.vload" , SI);
1362
1384
}
1363
1385
Load->setDebugLoc (SI->getDebugLoc ());
1364
- auto EM = loadExecutionMask (SI, SimdWidth, NumChannels);
1386
+ Value *EM = loadExecutionMask (SI, SimdWidth);
1387
+
1388
+ // If there was a predicate already then update it with current EM
1389
+ if (ExistingPred) {
1390
+ EM = BinaryOperator::Create (
1391
+ Instruction::And, ExistingPred, EM,
1392
+ ExistingPred->getName () + " .and." + EM->getName (), SI);
1393
+ cast<Instruction>(EM)->setDebugLoc (SI->getDebugLoc ());
1394
+ }
1395
+
1396
+ // Replicate mask for each channel if needed
1397
+ EM = replicateMask (EM, SI, SimdWidth, NumChannels);
1398
+
1365
1399
auto Select = SelectInst::Create (EM, SI->getOperand (0 ), Load,
1366
1400
SI->getOperand (0 )->getName () + " .simdcfpred" , SI);
1367
1401
SI->setOperand (0 , Select);
@@ -1450,16 +1484,26 @@ void CMSimdCFLower::predicateScatterGather(CallInst *CI, unsigned SimdWidth,
1450
1484
{
1451
1485
Value *OldPred = CI->getArgOperand (PredOperandNum);
1452
1486
assert (OldPred->getType ()->getScalarType ()->isIntegerTy (1 ));
1453
- if (SimdWidth != VCINTR::VectorType::getNumElements (
1454
- cast<VectorType>(OldPred->getType ()))) {
1455
- DiagnosticInfoSimdCF::emit (CI, " mismatching SIMD width of scatter/gather inside SIMD control flow" );
1456
- return ;
1487
+ switch (GenXIntrinsic::getGenXIntrinsicID (CI)) {
1488
+ case GenXIntrinsic::genx_gather4_masked_scaled2:
1489
+ break ;
1490
+ default : {
1491
+ if (SimdWidth != VCINTR::VectorType::getNumElements (
1492
+ cast<VectorType>(OldPred->getType ()))) {
1493
+ DiagnosticInfoSimdCF::emit (
1494
+ CI,
1495
+ " mismatching SIMD width of scatter/gather inside SIMD control flow" );
1496
+ return ;
1497
+ }
1498
+ break ;
1499
+ }
1457
1500
}
1458
1501
Instruction *NewPred = loadExecutionMask (CI, SimdWidth);
1459
1502
if (auto C = dyn_cast<Constant>(OldPred))
1460
1503
if (C->isAllOnesValue ())
1461
1504
OldPred = nullptr ;
1462
1505
if (OldPred) {
1506
+ OriginalPred[CI] = OldPred;
1463
1507
auto And = BinaryOperator::Create (Instruction::And, OldPred, NewPred,
1464
1508
OldPred->getName () + " .and." + NewPred->getName (), CI);
1465
1509
And->setDebugLoc (CI->getDebugLoc ());
@@ -1496,6 +1540,7 @@ CallInst *CMSimdCFLower::predicateWrRegion(CallInst *WrR, unsigned SimdWidth)
1496
1540
if (!Pred)
1497
1541
Pred = EM;
1498
1542
else {
1543
+ OriginalPred[WrR] = Pred;
1499
1544
auto And = BinaryOperator::Create (Instruction::And, EM, Pred,
1500
1545
Pred->getName () + " .and." + EM->getName (), WrR);
1501
1546
And->setDebugLoc (WrR->getDebugLoc ());
@@ -1783,39 +1828,46 @@ CallInst *CMSimdCFLower::isSimdCFAny(Value *V)
1783
1828
return nullptr ;
1784
1829
}
1785
1830
1831
+ /* **********************************************************************
1832
+ * replicateMask : copy mask for provided number of channels using shufflevector
1833
+ */
1834
+ Value *CMSimdCFLower::replicateMask (Value *EM, Instruction *InsertBefore,
1835
+ unsigned SimdWidth, unsigned NumChannels) {
1836
+ // No need to replicate the mask for one channel
1837
+ if (NumChannels == 1 )
1838
+ return EM;
1839
+
1840
+ SmallVector<Constant *, 128 > ChannelMask{SimdWidth * NumChannels};
1841
+ for (unsigned i = 0 ; i < NumChannels; ++i)
1842
+ std::copy (ShuffleMask.begin (), ShuffleMask.begin () + SimdWidth,
1843
+ ChannelMask.begin () + SimdWidth * i);
1844
+ EM = new ShuffleVectorInst (
1845
+ EM, UndefValue::get (EM->getType ()), ConstantVector::get (ChannelMask),
1846
+ Twine (" ChannelEM" ) + Twine (SimdWidth), InsertBefore);
1847
+
1848
+ return EM;
1849
+ }
1850
+
1786
1851
/* **********************************************************************
1787
1852
* loadExecutionMask : create instruction to load EM
1788
1853
*/
1789
1854
Instruction *CMSimdCFLower::loadExecutionMask (Instruction *InsertBefore,
1790
- unsigned SimdWidth, unsigned NumChannels)
1791
- {
1855
+ unsigned SimdWidth) {
1792
1856
Instruction *EM =
1793
1857
new LoadInst (EMVar->getType ()->getPointerElementType (), EMVar,
1794
1858
EMVar->getName (), false /* isVolatile */ , InsertBefore);
1795
- EM-> setDebugLoc (InsertBefore-> getDebugLoc ());
1859
+
1796
1860
// If the simd width is not MAX_SIMD_CF_WIDTH, extract the part of EM we want.
1797
- if (NumChannels == 1 && SimdWidth == MAX_SIMD_CF_WIDTH)
1861
+ if (SimdWidth == MAX_SIMD_CF_WIDTH)
1798
1862
return EM;
1799
- if (ShuffleMask.empty ()) {
1800
- auto I32Ty = Type::getInt32Ty (F->getContext ());
1801
- for (unsigned i = 0 ; i != 32 ; ++i)
1802
- ShuffleMask.push_back (ConstantInt::get (I32Ty, i));
1803
- }
1804
- if (NumChannels == 1 ) {
1805
- ArrayRef<Constant *> Mask = ShuffleMask;
1806
- EM = new ShuffleVectorInst (EM, UndefValue::get (EM->getType ()),
1807
- ConstantVector::get (Mask.take_front (SimdWidth)),
1808
- Twine (" EM" ) + Twine (SimdWidth), InsertBefore);
1809
- } else {
1810
- SmallVector<Constant *, 128 > ChannelMask{SimdWidth * NumChannels};
1811
- for (unsigned i = 0 ; i < NumChannels; ++i)
1812
- std::copy (ShuffleMask.begin (), ShuffleMask.begin () + SimdWidth,
1813
- ChannelMask.begin () + SimdWidth * i);
1814
- EM = new ShuffleVectorInst (
1815
- EM, UndefValue::get (EM->getType ()), ConstantVector::get (ChannelMask),
1816
- Twine (" ChannelEM" ) + Twine (SimdWidth), InsertBefore);
1817
- }
1863
+
1864
+ ArrayRef<Constant *> Mask = ShuffleMask;
1865
+ EM = new ShuffleVectorInst (EM, UndefValue::get (EM->getType ()),
1866
+ ConstantVector::get (Mask.take_front (SimdWidth)),
1867
+ Twine (" EM" ) + Twine (SimdWidth), InsertBefore);
1868
+
1818
1869
EM->setDebugLoc (InsertBefore->getDebugLoc ());
1870
+
1819
1871
return EM;
1820
1872
}
1821
1873
0 commit comments