@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
383
383
return Vectors.size ();
384
384
else {
385
385
assert (Vectors.size () > 0 && " Cannot call getNumRows without columns" );
386
- return cast<FixedVectorType>(Vectors[ 0 ]-> getType () )->getNumElements ();
386
+ return getVectorTy ( )->getNumElements ();
387
387
}
388
388
}
389
389
unsigned getNumRows () const {
390
390
if (isColumnMajor ()) {
391
391
assert (Vectors.size () > 0 && " Cannot call getNumRows without columns" );
392
- return cast<FixedVectorType>(Vectors[ 0 ]-> getType () )->getNumElements ();
392
+ return getVectorTy ( )->getNumElements ();
393
393
} else
394
394
return Vectors.size ();
395
395
}
396
396
397
397
void addVector (Value *V) { Vectors.push_back (V); }
398
- VectorType *getColumnTy () {
398
+ FixedVectorType *getColumnTy () {
399
399
assert (isColumnMajor () && " only supported for column-major matrixes" );
400
400
return getVectorTy ();
401
401
}
402
402
403
- VectorType *getVectorTy () const {
404
- return cast<VectorType >(Vectors[0 ]->getType ());
403
+ FixedVectorType *getVectorTy () const {
404
+ return cast<FixedVectorType >(Vectors[0 ]->getType ());
405
405
}
406
406
407
407
iterator_range<SmallVector<Value *, 8 >::iterator> columns () {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
514
514
: Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
515
515
516
516
unsigned getNumOps (Type *VT) {
517
- assert (isa<VectorType >(VT) && " Expected vector type" );
517
+ assert (isa<FixedVectorType >(VT) && " Expected vector type" );
518
518
return getNumOps (VT->getScalarType (),
519
519
cast<FixedVectorType>(VT)->getNumElements ());
520
520
}
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
540
540
// / into vectors.
541
541
MatrixTy getMatrix (Value *MatrixVal, const ShapeInfo &SI,
542
542
IRBuilder<> &Builder) {
543
- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType ());
544
- assert (VType && " MatrixVal must be a vector type" );
545
- assert (cast<FixedVectorType>(VType)->getNumElements () ==
546
- SI.NumRows * SI.NumColumns &&
543
+ FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType ());
544
+ assert (VType->getNumElements () == SI.NumRows * SI.NumColumns &&
547
545
" The vector size must match the number of matrix elements" );
548
546
549
547
// Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
563
561
564
562
// Otherwise split MatrixVal.
565
563
SmallVector<Value *, 16 > SplitVecs;
566
- for (unsigned MaskStart = 0 ;
567
- MaskStart < cast<FixedVectorType>(VType)->getNumElements ();
564
+ for (unsigned MaskStart = 0 ; MaskStart < VType->getNumElements ();
568
565
MaskStart += SI.getStride ()) {
569
566
Value *V = Builder.CreateShuffleVector (
570
567
MatrixVal, createSequentialMask (MaskStart, SI.getStride (), 0 ),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
1157
1154
// / vectors.
1158
1155
MatrixTy loadMatrix (Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
1159
1156
bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1160
- auto *VType = cast<VectorType >(Ty);
1157
+ auto *VType = cast<FixedVectorType >(Ty);
1161
1158
Type *EltTy = VType->getElementType ();
1162
1159
Type *VecTy = FixedVectorType::get (EltTy, Shape.getStride ());
1163
1160
Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
1239
1236
MatrixTy storeMatrix (Type *Ty, MatrixTy StoreVal, Value *Ptr,
1240
1237
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
1241
1238
IRBuilder<> &Builder) {
1242
- auto VType = cast<VectorType >(Ty);
1239
+ auto * VType = cast<FixedVectorType >(Ty);
1243
1240
Value *EltPtr = Ptr;
1244
1241
for (auto Vec : enumerate(StoreVal.vectors ())) {
1245
1242
Value *GEP = computeVectorAddr (
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
1377
1374
Value *LHS = MatMul->getArgOperand (0 );
1378
1375
Value *RHS = MatMul->getArgOperand (1 );
1379
1376
1380
- Type *ElementType = cast<VectorType >(LHS->getType ())->getElementType ();
1377
+ Type *ElementType = cast<FixedVectorType >(LHS->getType ())->getElementType ();
1381
1378
bool IsIntVec = ElementType->isIntegerTy ();
1382
1379
1383
1380
// Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
1475
1472
int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1476
1473
InstructionCost ReductionCost =
1477
1474
TTI.getArithmeticReductionCost (
1478
- AddOpCode, cast<VectorType >(LHS->getType ()),
1475
+ AddOpCode, cast<FixedVectorType >(LHS->getType ()),
1479
1476
IsIntVec ? std::nullopt : std::optional (FMF)) +
1480
1477
TTI.getArithmeticInstrCost (MulOpCode, LHS->getType ());
1481
1478
InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
1535
1532
Result = Builder.CreateAddReduce (Mul);
1536
1533
else {
1537
1534
Result = Builder.CreateFAddReduce (
1538
- ConstantFP::get (cast<VectorType>(LHS-> getType ())-> getElementType (),
1539
- 0.0 ),
1535
+ ConstantFP::get (
1536
+ cast<FixedVectorType>(LHS-> getType ())-> getElementType (), 0.0 ),
1540
1537
Mul);
1541
1538
cast<Instruction>(Result)->setFastMathFlags (FMF);
1542
1539
}
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
1735
1732
const unsigned R = LShape.NumRows ;
1736
1733
const unsigned C = RShape.NumColumns ;
1737
1734
const unsigned M = LShape.NumColumns ;
1738
- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1735
+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
1739
1736
1740
1737
const unsigned VF = std::max<unsigned >(
1741
1738
TTI.getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
1771
1768
1772
1769
void createTiledLoops (CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1773
1770
Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1774
- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1771
+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
1775
1772
1776
1773
// Create the main tiling loop nest.
1777
1774
TileInfo TI (LShape.NumRows , RShape.NumColumns , LShape.NumColumns , TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
1842
1839
const unsigned R = LShape.NumRows ;
1843
1840
const unsigned C = RShape.NumColumns ;
1844
1841
const unsigned M = LShape.NumColumns ;
1845
- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
1842
+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
1846
1843
1847
1844
Value *APtr = getNonAliasingPointer (LoadOp0, Store, MatMul);
1848
1845
Value *BPtr = getNonAliasingPointer (LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
1914
1911
? match (B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (T)))
1915
1912
: match (A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (T)))) {
1916
1913
IRBuilder<> Builder (MatMul);
1917
- auto *EltType = cast<VectorType>(MatMul->getType ())->getElementType ();
1914
+ auto *EltType =
1915
+ cast<FixedVectorType>(MatMul->getType ())->getElementType ();
1918
1916
ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
1919
1917
ShapeInfo RShape (MatMul->getArgOperand (3 ), MatMul->getArgOperand (4 ));
1920
1918
const unsigned R = LShape.NumRows ;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
2045
2043
// / Lowers llvm.matrix.multiply.
2046
2044
void LowerMultiply (CallInst *MatMul) {
2047
2045
IRBuilder<> Builder (MatMul);
2048
- auto *EltType = cast<VectorType >(MatMul->getType ())->getElementType ();
2046
+ auto *EltType = cast<FixedVectorType >(MatMul->getType ())->getElementType ();
2049
2047
ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
2050
2048
ShapeInfo RShape (MatMul->getArgOperand (3 ), MatMul->getArgOperand (4 ));
2051
2049
@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
2073
2071
MatrixTy Result;
2074
2072
IRBuilder<> Builder (Inst);
2075
2073
Value *InputVal = Inst->getArgOperand (0 );
2076
- VectorType *VectorTy = cast<VectorType >(InputVal->getType ());
2074
+ FixedVectorType *VectorTy = cast<FixedVectorType >(InputVal->getType ());
2077
2075
ShapeInfo ArgShape (Inst->getArgOperand (1 ), Inst->getArgOperand (2 ));
2078
2076
MatrixTy InputMatrix = getMatrix (InputVal, ArgShape, Builder);
2079
2077
0 commit comments