@@ -485,10 +485,14 @@ static bool CheckAllocaUsesInternal(Instruction *I) {
485
485
llvm::dyn_cast<llvm::BitCastInst>(*use_it)) {
486
486
if (pBitCast->use_empty ())
487
487
continue ;
488
- Type *baseT =
489
- GetBaseType (pBitCast->getType ()->getPointerElementType (), nullptr );
490
- Type *sourceType = GetBaseType (
491
- pBitCast->getOperand (0 )->getType ()->getPointerElementType (), nullptr );
488
+ Type *baseT = GetBaseType (
489
+ pBitCast->getType ()->getScalarType ()->getPointerElementType (),
490
+ nullptr );
491
+ Type *sourceType = GetBaseType (pBitCast->getOperand (0 )
492
+ ->getType ()
493
+ ->getScalarType ()
494
+ ->getPointerElementType (),
495
+ nullptr );
492
496
IGC_ASSERT (sourceType);
493
497
// either the point-to-element-type is the same or
494
498
// the point-to-element-type is the byte or a function pointer
@@ -648,10 +652,11 @@ void TransposeHelper::EraseDeadCode() {
648
652
649
653
void TransposeHelper::handleBCInst (BitCastInst &BC, GenericVectorIndex Idx) {
650
654
m_toBeRemoved.push_back (&BC);
651
- Type *DstDerefTy =
652
- GetBaseType ( BC.getType ()->getPointerElementType (), nullptr );
655
+ Type *DstDerefTy = GetBaseType (
656
+ BC.getType ()-> getScalarType ()->getPointerElementType (), nullptr );
653
657
Type *SrcDerefTy = GetBaseType (
654
- BC.getOperand (0 )->getType ()->getPointerElementType (), nullptr );
658
+ BC.getOperand (0 )->getType ()->getScalarType ()->getPointerElementType (),
659
+ nullptr );
655
660
IGC_ASSERT (DstDerefTy);
656
661
IGC_ASSERT (SrcDerefTy);
657
662
// either the point-to-element-type is the same or
@@ -855,27 +860,52 @@ void TransposeHelper::handlePHINode(PHINode *pPhi, GenericVectorIndex Idx,
855
860
handleAllocaSources (*pPhi, {NewPhi, Idx.ElementSizeInBits });
856
861
}
857
862
863
+ // Loads vector and casts it if necessary.
864
+ // \p CastTo describes vector element type to cast to.
865
+ template <typename FolderT = ConstantFolder>
866
+ Instruction *loadAndCastVector (AllocaInst &VecAlloca, Type &CastTo,
867
+ IRBuilder<FolderT> &IRB) {
868
+ auto *LoadVecAlloca = IRB.CreateLoad (&VecAlloca);
869
+ auto *AllocatedElemTy = LoadVecAlloca->getType ()->getScalarType ();
870
+ bool IsFuncPointer =
871
+ CastTo.isPointerTy () && CastTo.getPointerElementType ()->isFunctionTy ();
872
+ if (AllocatedElemTy == &CastTo || IsFuncPointer)
873
+ return LoadVecAlloca;
874
+ auto AllocatedWidth = cast<IGCLLVM::FixedVectorType>(LoadVecAlloca->getType ())
875
+ ->getNumElements ();
876
+ IGC_ASSERT (AllocatedElemTy->getScalarSizeInBits () >=
877
+ CastTo.getScalarSizeInBits ());
878
+ IGC_ASSERT (CastTo.getScalarSizeInBits ());
879
+ IGC_ASSERT ((AllocatedElemTy->getScalarSizeInBits () %
880
+ CastTo.getScalarSizeInBits ()) == 0 );
881
+ auto CastedWidth = AllocatedWidth * (AllocatedElemTy->getScalarSizeInBits () /
882
+ CastTo.getScalarSizeInBits ());
883
+ return cast<Instruction>(IRB.CreateBitCast (
884
+ LoadVecAlloca, IGCLLVM::FixedVectorType::get (&CastTo, CastedWidth),
885
+ " post.load.bc" ));
886
+ }
887
+
888
+ // Casts \p NewValue if its type doesn't correspond to allocated vector type,
889
+ // then stores the value.
890
+ template <typename FolderT = ConstantFolder>
891
+ Instruction *castAndStoreVector (AllocaInst &VecAlloca, Value &NewValue,
892
+ IRBuilder<FolderT> &IRB) {
893
+ auto *CastedValue = &NewValue;
894
+ if (VecAlloca.getAllocatedType () != NewValue.getType ())
895
+ CastedValue = IRB.CreateBitCast (&NewValue, VecAlloca.getAllocatedType (),
896
+ NewValue.getName () + " .pre.store.bc" );
897
+ return IRB.CreateStore (CastedValue, &VecAlloca);
898
+ }
899
+
858
900
void TransposeHelperPromote::handleLoadInst (LoadInst *pLoad,
859
901
Value *pScalarizedIdx) {
860
902
IGC_ASSERT (pLoad->isSimple ());
861
903
IRBuilder<> IRB (pLoad);
862
- Value *pLoadVecAlloca = IRB.CreateLoad (pVecAlloca);
863
904
auto LdTy = pLoad->getType ()->getScalarType ();
864
- auto VETy = pLoadVecAlloca->getType ()->getScalarType ();
865
- auto ReadIn = pLoadVecAlloca;
905
+ auto *ReadIn = loadAndCastVector (*pVecAlloca, *LdTy, IRB);
866
906
bool IsFuncPointer = pLoad->getPointerOperandType ()->isPointerTy () &&
867
907
pLoad->getPointerOperandType ()->getPointerElementType ()->isPointerTy () &&
868
908
pLoad->getPointerOperandType ()->getPointerElementType ()->getPointerElementType ()->isFunctionTy ();
869
- // do the type-casting if necessary
870
- if (VETy != LdTy && !IsFuncPointer) {
871
- auto VLen = cast<VectorType>(pLoadVecAlloca->getType ())->getNumElements ();
872
- IGC_ASSERT (VETy->getScalarSizeInBits () >= LdTy->getScalarSizeInBits ());
873
- IGC_ASSERT (LdTy->getScalarSizeInBits ());
874
- IGC_ASSERT ((VETy->getScalarSizeInBits () % LdTy->getScalarSizeInBits ()) == 0 );
875
- VLen = VLen * (VETy->getScalarSizeInBits () / LdTy->getScalarSizeInBits ());
876
- ReadIn =
877
- IRB.CreateBitCast (ReadIn, IGCLLVM::FixedVectorType::get (LdTy, VLen));
878
- }
879
909
if (IsFuncPointer) {
880
910
Region R (
881
911
IGCLLVM::FixedVectorType::get (
@@ -891,7 +921,7 @@ void TransposeHelperPromote::handleLoadInst(LoadInst *pLoad,
891
921
pScalarizedIdx = IRB.CreateZExtOrTrunc (pScalarizedIdx, Type::getInt16Ty (pLoad->getContext ()));
892
922
}
893
923
R.Indirect = pScalarizedIdx;
894
- auto *Result = R.createRdRegion (pLoadVecAlloca , pLoad->getName (), pLoad,
924
+ auto *Result = R.createRdRegion (ReadIn , pLoad->getName (), pLoad,
895
925
pLoad->getDebugLoc (), true );
896
926
if (!Result->getType ()->isPointerTy ()) {
897
927
auto *BC =
@@ -931,25 +961,13 @@ void TransposeHelperPromote::handleStoreInst(StoreInst *pStore,
931
961
IGC_ASSERT (pStore->isSimple ());
932
962
IRBuilder<> IRB (pStore);
933
963
llvm::Value *pStoreVal = pStore->getValueOperand ();
934
- llvm::Value *pLoadVecAlloca = IRB.CreateLoad (pVecAlloca);
935
- llvm::Value *WriteOut = pLoadVecAlloca;
936
964
auto *StTy = pStoreVal->getType ()->getScalarType ();
937
- auto *VETy = pLoadVecAlloca->getType ()->getScalarType ();
938
- // do the type-casting if necessary
965
+ Value *WriteOut = loadAndCastVector (*pVecAlloca, *StTy, IRB);
939
966
940
967
bool IsFuncPointerStore =
941
968
(isFuncPointerVec (pStoreVal) ||
942
969
(pStoreVal->getType ()->isPointerTy () &&
943
970
pStoreVal->getType ()->getPointerElementType ()->isFunctionTy ()));
944
- if (VETy != StTy && !IsFuncPointerStore) {
945
- auto VLen = cast<VectorType>(pLoadVecAlloca->getType ())->getNumElements ();
946
- IGC_ASSERT (VETy->getScalarSizeInBits () >= StTy->getScalarSizeInBits ());
947
- IGC_ASSERT (StTy->getScalarSizeInBits ());
948
- IGC_ASSERT ((VETy->getScalarSizeInBits () % StTy->getScalarSizeInBits ()) == 0 );
949
- VLen = VLen * (VETy->getScalarSizeInBits () / StTy->getScalarSizeInBits ());
950
- WriteOut =
951
- IRB.CreateBitCast (WriteOut, IGCLLVM::FixedVectorType::get (StTy, VLen));
952
- }
953
971
if (IsFuncPointerStore) {
954
972
auto *NewStoreVal = pStoreVal;
955
973
IGC_ASSERT (cast<VectorType>(pVecAlloca->getType ()->getPointerElementType ())->getElementType ()->isIntegerTy (64 ));
@@ -1000,10 +1018,7 @@ void TransposeHelperPromote::handleStoreInst(StoreInst *pStore,
1000
1018
WriteOut =
1001
1019
IRB.CreateInsertElement (WriteOut, pStoreVal, ScalarizedIdx.Index );
1002
1020
}
1003
- // cast the vector type back if necessary
1004
- if (VETy != StTy)
1005
- WriteOut = IRB.CreateBitCast (WriteOut, pLoadVecAlloca->getType ());
1006
- IRB.CreateStore (WriteOut, pVecAlloca);
1021
+ castAndStoreVector (*pVecAlloca, *WriteOut, IRB);
1007
1022
pStore->eraseFromParent ();
1008
1023
}
1009
1024
@@ -1152,9 +1167,9 @@ void TransposeHelperPromote::handleLLVMGather(IntrinsicInst *pInst,
1152
1167
Value *pScalarizedIdx) {
1153
1168
IRBuilder<> IRB (pInst);
1154
1169
IGC_ASSERT (pInst->getType ()->isVectorTy ());
1155
- Value *pLoadVecAlloca = IRB.CreateLoad (pVecAlloca);
1156
1170
auto N = cast<VectorType>(pInst->getType ())->getNumElements ();
1157
1171
auto ElemType = cast<VectorType>(pInst->getType ())->getElementType ();
1172
+ Value *LoadVecAlloca = loadAndCastVector (*pVecAlloca, *ElemType, IRB);
1158
1173
1159
1174
// A vector load
1160
1175
// %v = <2 x float> gather %pred, %vector_of_ptr, %old_value
@@ -1192,8 +1207,8 @@ void TransposeHelperPromote::handleLLVMGather(IntrinsicInst *pInst,
1192
1207
R.VStride = 0 ;
1193
1208
}
1194
1209
Value *Result =
1195
- R.createRdRegion (pLoadVecAlloca , pInst->getName (), pInst /* InsertBefore*/ ,
1196
- pInst->getDebugLoc (), true /* AllowScalar*/ );
1210
+ R.createRdRegion (LoadVecAlloca , pInst->getName (), pInst /* InsertBefore*/ ,
1211
+ pInst->getDebugLoc (), true /* AllowScalar*/ );
1197
1212
1198
1213
// if old-value is not undefined and predicate is not all-one,
1199
1214
// create a select auto OldVal = pInst->getArgOperand(3);
@@ -1216,20 +1231,22 @@ void TransposeHelperPromote::handleLLVMScatter(llvm::IntrinsicInst *pInst,
1216
1231
llvm::Value *pScalarizedIdx) {
1217
1232
// Add Store instruction to remove list
1218
1233
IRBuilder<> IRB (pInst);
1219
- llvm::Value *pStoreVal = pInst->getArgOperand (3 );
1220
- llvm::Value *pLoadVecAlloca = IRB.CreateLoad (pVecAlloca);
1221
- IGC_ASSERT (pStoreVal->getType ()->isVectorTy ());
1222
- auto N = cast<VectorType>(pStoreVal->getType ())->getNumElements ();
1223
- auto ElemType = cast<VectorType>(pStoreVal->getType ())->getElementType ();
1234
+ Value *StoredValue = pInst->getArgOperand (0 );
1235
+ IGC_ASSERT (StoredValue->getType ()->isVectorTy ());
1236
+ auto N =
1237
+ cast<IGCLLVM::FixedVectorType>(StoredValue->getType ())->getNumElements ();
1238
+ auto *ElemType =
1239
+ cast<IGCLLVM::FixedVectorType>(StoredValue->getType ())->getElementType ();
1240
+ Value *LoadVecAlloca = loadAndCastVector (*pVecAlloca, *ElemType, IRB);
1224
1241
// A vector scatter
1225
- // scatter %pred , %ptr, %offset , %newvalue
1242
+ // scatter %newvalue , %ptr, i32 align , %pred
1226
1243
// becomes
1227
- // %w = load <32 x float> *%ptr1
1228
- // %w1 = <32 x float> wrregion %w, newvalue, %offset , %pred
1229
- // store <32 x float> %w1, <32 x float>* %ptr1
1244
+ // %w = load %vec.alloca
1245
+ // %w1 = wrregion %w, % newvalue, %indexed.ptr , %pred
1246
+ // store %w1, %vec.alloca
1230
1247
1231
1248
// Create the new wrregion
1232
- Region R (pStoreVal) ;
1249
+ Region R{StoredValue} ;
1233
1250
int64_t v0 = 0 ;
1234
1251
int64_t diff = 0 ;
1235
1252
// pScalarizedIdx is an indice of element, so
@@ -1259,12 +1276,12 @@ void TransposeHelperPromote::handleLLVMScatter(llvm::IntrinsicInst *pInst,
1259
1276
R.Stride = 0 ;
1260
1277
R.VStride = 0 ;
1261
1278
}
1262
- R.Mask = pInst->getArgOperand (0 );
1279
+ R.Mask = pInst->getArgOperand (3 );
1263
1280
auto NewInst = cast<Instruction>(
1264
- R.createWrRegion (pLoadVecAlloca, pStoreVal , pInst->getName (),
1265
- pInst /* InsertBefore*/ , pInst->getDebugLoc ()));
1281
+ R.createWrRegion (LoadVecAlloca, StoredValue , pInst->getName (),
1282
+ pInst /* InsertBefore*/ , pInst->getDebugLoc ()));
1266
1283
1267
- IRB. CreateStore ( NewInst, pVecAlloca );
1284
+ castAndStoreVector (*pVecAlloca, * NewInst, IRB );
1268
1285
pInst->eraseFromParent ();
1269
1286
}
1270
1287
0 commit comments