Skip to content

Commit 5c28af4

Browse files
authored
[Matrix] Use FixedVectorType everywhere in LowerMatrixIntrinsics. NFC (#142316)
These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes.
1 parent 3b4c51b commit 5c28af4

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
383383
return Vectors.size();
384384
else {
385385
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
386-
return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
386+
return getVectorTy()->getNumElements();
387387
}
388388
}
389389
unsigned getNumRows() const {
390390
if (isColumnMajor()) {
391391
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
392-
return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
392+
return getVectorTy()->getNumElements();
393393
} else
394394
return Vectors.size();
395395
}
396396

397397
void addVector(Value *V) { Vectors.push_back(V); }
398-
VectorType *getColumnTy() {
398+
FixedVectorType *getColumnTy() {
399399
assert(isColumnMajor() && "only supported for column-major matrixes");
400400
return getVectorTy();
401401
}
402402

403-
VectorType *getVectorTy() const {
404-
return cast<VectorType>(Vectors[0]->getType());
403+
FixedVectorType *getVectorTy() const {
404+
return cast<FixedVectorType>(Vectors[0]->getType());
405405
}
406406

407407
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
514514
: Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
515515

516516
unsigned getNumOps(Type *VT) {
517-
assert(isa<VectorType>(VT) && "Expected vector type");
517+
assert(isa<FixedVectorType>(VT) && "Expected vector type");
518518
return getNumOps(VT->getScalarType(),
519519
cast<FixedVectorType>(VT)->getNumElements());
520520
}
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
540540
/// into vectors.
541541
MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
542542
IRBuilder<> &Builder) {
543-
VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
544-
assert(VType && "MatrixVal must be a vector type");
545-
assert(cast<FixedVectorType>(VType)->getNumElements() ==
546-
SI.NumRows * SI.NumColumns &&
543+
FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
544+
assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
547545
"The vector size must match the number of matrix elements");
548546

549547
// Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
563561

564562
// Otherwise split MatrixVal.
565563
SmallVector<Value *, 16> SplitVecs;
566-
for (unsigned MaskStart = 0;
567-
MaskStart < cast<FixedVectorType>(VType)->getNumElements();
564+
for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
568565
MaskStart += SI.getStride()) {
569566
Value *V = Builder.CreateShuffleVector(
570567
MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
11571154
/// vectors.
11581155
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
11591156
bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
1160-
auto *VType = cast<VectorType>(Ty);
1157+
auto *VType = cast<FixedVectorType>(Ty);
11611158
Type *EltTy = VType->getElementType();
11621159
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
11631160
Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
12391236
MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
12401237
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
12411238
IRBuilder<> &Builder) {
1242-
auto VType = cast<VectorType>(Ty);
1239+
auto *VType = cast<FixedVectorType>(Ty);
12431240
Value *EltPtr = Ptr;
12441241
for (auto Vec : enumerate(StoreVal.vectors())) {
12451242
Value *GEP = computeVectorAddr(
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
13771374
Value *LHS = MatMul->getArgOperand(0);
13781375
Value *RHS = MatMul->getArgOperand(1);
13791376

1380-
Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
1377+
Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
13811378
bool IsIntVec = ElementType->isIntegerTy();
13821379

13831380
// Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
14751472
int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
14761473
InstructionCost ReductionCost =
14771474
TTI.getArithmeticReductionCost(
1478-
AddOpCode, cast<VectorType>(LHS->getType()),
1475+
AddOpCode, cast<FixedVectorType>(LHS->getType()),
14791476
IsIntVec ? std::nullopt : std::optional(FMF)) +
14801477
TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
14811478
InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
15351532
Result = Builder.CreateAddReduce(Mul);
15361533
else {
15371534
Result = Builder.CreateFAddReduce(
1538-
ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
1539-
0.0),
1535+
ConstantFP::get(
1536+
cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
15401537
Mul);
15411538
cast<Instruction>(Result)->setFastMathFlags(FMF);
15421539
}
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
17351732
const unsigned R = LShape.NumRows;
17361733
const unsigned C = RShape.NumColumns;
17371734
const unsigned M = LShape.NumColumns;
1738-
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1735+
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
17391736

17401737
const unsigned VF = std::max<unsigned>(
17411738
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
17711768

17721769
void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
17731770
Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
1774-
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1771+
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
17751772

17761773
// Create the main tiling loop nest.
17771774
TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
18421839
const unsigned R = LShape.NumRows;
18431840
const unsigned C = RShape.NumColumns;
18441841
const unsigned M = LShape.NumColumns;
1845-
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1842+
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
18461843

18471844
Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
18481845
Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
19141911
? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
19151912
: match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
19161913
IRBuilder<> Builder(MatMul);
1917-
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
1914+
auto *EltType =
1915+
cast<FixedVectorType>(MatMul->getType())->getElementType();
19181916
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
19191917
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
19201918
const unsigned R = LShape.NumRows;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
20452043
/// Lowers llvm.matrix.multiply.
20462044
void LowerMultiply(CallInst *MatMul) {
20472045
IRBuilder<> Builder(MatMul);
2048-
auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
2046+
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
20492047
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
20502048
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
20512049

@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
20732071
MatrixTy Result;
20742072
IRBuilder<> Builder(Inst);
20752073
Value *InputVal = Inst->getArgOperand(0);
2076-
VectorType *VectorTy = cast<VectorType>(InputVal->getType());
2074+
FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
20772075
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
20782076
MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
20792077

0 commit comments

Comments
 (0)