@@ -344,9 +344,10 @@ unsigned GCNTTIImpl::getMinVectorRegisterBitWidth() const {
344
344
unsigned GCNTTIImpl::getMaximumVF (unsigned ElemWidth, unsigned Opcode) const {
345
345
if (Opcode == Instruction::Load || Opcode == Instruction::Store)
346
346
return 32 * 4 / ElemWidth;
347
- return (ElemWidth == 16 && ST->has16BitInsts ()) ? 2
348
- : (ElemWidth == 32 && ST->hasPackedFP32Ops ()) ? 2
349
- : 1 ;
347
+ return ElemWidth == 8 ? 4
348
+ : (ElemWidth == 16 && ST->has16BitInsts ()) ? 2
349
+ : (ElemWidth == 32 && ST->hasPackedFP32Ops ()) ? 2
350
+ : 1 ;
350
351
}
351
352
352
353
unsigned GCNTTIImpl::getLoadVectorFactor (unsigned VF, unsigned LoadSize,
@@ -537,6 +538,12 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost(
537
538
538
539
MVT::SimpleValueType SLT = LT.second .getScalarType ().SimpleTy ;
539
540
541
+ VectorType *VecTy = dyn_cast<VectorType>(Ty);
542
+ InstructionCost LTTypeCost = LT.first ;
543
+ if (VecTy &&
544
+ VecTy->getElementType () == IntegerType::getInt8Ty (VecTy->getContext ()))
545
+ LTTypeCost = (((LT.first - 1 ) / 4 ) + 1 );
546
+
540
547
switch (ISD) {
541
548
case ISD::SHL:
542
549
case ISD::SRL:
@@ -548,7 +555,7 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost(
548
555
NElts = (NElts + 1 ) / 2 ;
549
556
550
557
// i32
551
- return getFullRateInstrCost () * LT. first * NElts;
558
+ return getFullRateInstrCost () * LTTypeCost * NElts;
552
559
case ISD::ADD:
553
560
case ISD::SUB:
554
561
case ISD::AND:
@@ -562,7 +569,7 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost(
562
569
if (ST->has16BitInsts () && SLT == MVT::i16 )
563
570
NElts = (NElts + 1 ) / 2 ;
564
571
565
- return LT. first * NElts * getFullRateInstrCost ();
572
+ return LTTypeCost * NElts * getFullRateInstrCost ();
566
573
case ISD::MUL: {
567
574
const int QuarterRateCost = getQuarterRateInstrCost (CostKind);
568
575
if (SLT == MVT::i64 ) {
@@ -574,7 +581,7 @@ InstructionCost GCNTTIImpl::getArithmeticInstrCost(
574
581
NElts = (NElts + 1 ) / 2 ;
575
582
576
583
// i32
577
- return QuarterRateCost * NElts * LT. first ;
584
+ return QuarterRateCost * NElts * LTTypeCost ;
578
585
}
579
586
case ISD::FMUL:
580
587
// Check possible fuse {fadd|fsub}(a,fmul(b,c)) and return zero cost for
@@ -1423,3 +1430,27 @@ void GCNTTIImpl::collectKernelLaunchBounds(
1423
1430
LB.push_back ({" amdgpu-waves-per-eu[0]" , WavesPerEU.first });
1424
1431
LB.push_back ({" amdgpu-waves-per-eu[1]" , WavesPerEU.second });
1425
1432
}
1433
+
1434
+ InstructionCost GCNTTIImpl::getMemoryOpCost (unsigned Opcode, Type *Src,
1435
+ Align Alignment,
1436
+ unsigned AddressSpace,
1437
+ TTI::TargetCostKind CostKind,
1438
+ TTI::OperandValueInfo OpInfo,
1439
+ const Instruction *I) {
1440
+ if (VectorType *VecTy = dyn_cast<VectorType>(Src))
1441
+ if (Opcode == Instruction::Load && VecTy->getElementType () == IntegerType::getInt8Ty (VecTy->getContext ())) {
1442
+ unsigned ElementCount = VecTy->getElementCount ().getFixedValue ();
1443
+ return ((ElementCount - 1 ) / 4 ) + 1 ;
1444
+ }
1445
+ return BaseT::getMemoryOpCost (Opcode, Src, Alignment, AddressSpace, CostKind,
1446
+ OpInfo, I);
1447
+ }
1448
+
1449
+ unsigned GCNTTIImpl::getNumberOfParts (Type *Tp) {
1450
+ if (VectorType *VecTy = dyn_cast<VectorType>(Tp))
1451
+ if (VecTy->getElementType () == IntegerType::getInt8Ty (VecTy->getContext ())) {
1452
+ unsigned ElementCount = VecTy->getElementCount ().getFixedValue ();
1453
+ return ((ElementCount - 1 ) / 4 ) + 1 ;
1454
+ }
1455
+ return BaseT::getNumberOfParts (Tp);
1456
+ }
0 commit comments