Skip to content

Commit 70caa31

Browse files
committed
[Matrix] Refactor shape info computation (NFCI).
Factor out forward shape computation for a given instruction. This allows re-use in a follow-up fix.
1 parent 5c9f768 commit 70caa31

File tree

1 file changed

+105
-98
lines changed

1 file changed

+105
-98
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 105 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
192192
return VecStart;
193193
}
194194

195+
namespace {
196+
struct ShapeInfo {
197+
unsigned NumRows;
198+
unsigned NumColumns;
199+
200+
bool IsColumnMajor;
201+
202+
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
203+
: NumRows(NumRows), NumColumns(NumColumns),
204+
IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
205+
206+
ShapeInfo(Value *NumRows, Value *NumColumns)
207+
: ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
208+
cast<ConstantInt>(NumColumns)->getZExtValue()) {}
209+
210+
bool operator==(const ShapeInfo &other) {
211+
return NumRows == other.NumRows && NumColumns == other.NumColumns;
212+
}
213+
bool operator!=(const ShapeInfo &other) { return !(*this == other); }
214+
215+
/// Returns true if shape-information is defined, meaning both dimensions
216+
/// are != 0.
217+
operator bool() const {
218+
assert(NumRows == 0 || NumColumns != 0);
219+
return NumRows != 0;
220+
}
221+
222+
unsigned getStride() const {
223+
if (IsColumnMajor)
224+
return NumRows;
225+
return NumColumns;
226+
}
227+
228+
unsigned getNumVectors() const {
229+
if (IsColumnMajor)
230+
return NumColumns;
231+
return NumRows;
232+
}
233+
234+
/// Returns the transposed shape.
235+
ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
236+
};
237+
} // namespace
238+
239+
/// Returns true if \p V is an operation whose result has the same shape as
/// its matrix operands (an element-wise op), or is not an instruction at all
/// (arguments/constants impose no shape constraint of their own).
static bool isUniformShape(Value *V) {
  auto *Inst = dyn_cast<Instruction>(V);
  if (!Inst)
    return true;

  // Element-wise arithmetic preserves the operand shape. FMul covers scalar
  // multiply as well.
  unsigned Opcode = Inst->getOpcode();
  return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
         Opcode == Instruction::FMul || Opcode == Instruction::FNeg ||
         Opcode == Instruction::Add || Opcode == Instruction::Mul ||
         Opcode == Instruction::Sub;
}
257+
258+
/// Return the ShapeInfo for the result of \p I, it it can be determined.
259+
static std::optional<ShapeInfo>
260+
computeShapeInfoForInst(Instruction *I,
261+
const ValueMap<Value *, ShapeInfo> &ShapeMap) {
262+
Value *M;
263+
Value *N;
264+
Value *K;
265+
if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
266+
m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
267+
return ShapeInfo(M, K);
268+
if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
269+
m_Value(N)))) {
270+
// Flip dimensions.
271+
return ShapeInfo(N, M);
272+
}
273+
if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
274+
m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
275+
m_Value(N))))
276+
return ShapeInfo(N, M);
277+
if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
278+
m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
279+
return ShapeInfo(M, N);
280+
Value *MatrixA;
281+
if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
282+
auto OpShape = ShapeMap.find(MatrixA);
283+
if (OpShape != ShapeMap.end())
284+
return OpShape->second;
285+
}
286+
287+
if (isUniformShape(I)) {
288+
// Find the first operand that has a known shape and use that.
289+
for (auto &Op : I->operands()) {
290+
auto OpShape = ShapeMap.find(Op.get());
291+
if (OpShape != ShapeMap.end())
292+
return OpShape->second;
293+
}
294+
}
295+
return std::nullopt;
296+
}
297+
195298
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
196299
///
197300
/// Currently, the lowering for each matrix intrinsic is done as follows:
@@ -383,48 +486,6 @@ class LowerMatrixIntrinsics {
383486
}
384487
};
385488

386-
struct ShapeInfo {
387-
unsigned NumRows;
388-
unsigned NumColumns;
389-
390-
bool IsColumnMajor;
391-
392-
ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
393-
: NumRows(NumRows), NumColumns(NumColumns),
394-
IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
395-
396-
ShapeInfo(Value *NumRows, Value *NumColumns)
397-
: ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
398-
cast<ConstantInt>(NumColumns)->getZExtValue()) {}
399-
400-
bool operator==(const ShapeInfo &other) {
401-
return NumRows == other.NumRows && NumColumns == other.NumColumns;
402-
}
403-
bool operator!=(const ShapeInfo &other) { return !(*this == other); }
404-
405-
/// Returns true if shape-information is defined, meaning both dimensions
406-
/// are != 0.
407-
operator bool() const {
408-
assert(NumRows == 0 || NumColumns != 0);
409-
return NumRows != 0;
410-
}
411-
412-
unsigned getStride() const {
413-
if (IsColumnMajor)
414-
return NumRows;
415-
return NumColumns;
416-
}
417-
418-
unsigned getNumVectors() const {
419-
if (IsColumnMajor)
420-
return NumColumns;
421-
return NumRows;
422-
}
423-
424-
/// Returns the transposed shape.
425-
ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
426-
};
427-
428489
/// Maps instructions to their shape information. The shape information
429490
/// describes the shape to be used while lowering. This matches the shape of
430491
/// the result value of the instruction, with the only exceptions being store
@@ -554,25 +615,6 @@ class LowerMatrixIntrinsics {
554615
return true;
555616
}
556617

557-
bool isUniformShape(Value *V) {
558-
Instruction *I = dyn_cast<Instruction>(V);
559-
if (!I)
560-
return true;
561-
562-
switch (I->getOpcode()) {
563-
case Instruction::FAdd:
564-
case Instruction::FSub:
565-
case Instruction::FMul: // Scalar multiply.
566-
case Instruction::FNeg:
567-
case Instruction::Add:
568-
case Instruction::Mul:
569-
case Instruction::Sub:
570-
return true;
571-
default:
572-
return false;
573-
}
574-
}
575-
576618
/// Returns true if shape information can be used for \p V. The supported
577619
/// instructions must match the instructions that can be lowered by this pass.
578620
bool supportsShapeInfo(Value *V) {
@@ -610,43 +652,8 @@ class LowerMatrixIntrinsics {
610652

611653
// New entry, set the value and insert operands
612654
bool Propagate = false;
613-
614-
Value *MatrixA;
615-
Value *MatrixB;
616-
Value *M;
617-
Value *N;
618-
Value *K;
619-
if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
620-
m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
621-
m_Value(N), m_Value(K)))) {
622-
Propagate = setShapeInfo(Inst, {M, K});
623-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
624-
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
625-
// Flip dimensions.
626-
Propagate = setShapeInfo(Inst, {N, M});
627-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
628-
m_Value(MatrixA), m_Value(), m_Value(),
629-
m_Value(), m_Value(M), m_Value(N)))) {
630-
Propagate = setShapeInfo(Inst, {N, M});
631-
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
632-
m_Value(), m_Value(), m_Value(), m_Value(M),
633-
m_Value(N)))) {
634-
Propagate = setShapeInfo(Inst, {M, N});
635-
} else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
636-
auto OpShape = ShapeMap.find(MatrixA);
637-
if (OpShape != ShapeMap.end())
638-
setShapeInfo(Inst, OpShape->second);
639-
continue;
640-
} else if (isUniformShape(Inst)) {
641-
// Find the first operand that has a known shape and use that.
642-
for (auto &Op : Inst->operands()) {
643-
auto OpShape = ShapeMap.find(Op.get());
644-
if (OpShape != ShapeMap.end()) {
645-
Propagate |= setShapeInfo(Inst, OpShape->second);
646-
break;
647-
}
648-
}
649-
}
655+
if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
656+
Propagate = setShapeInfo(Inst, *SI);
650657

651658
if (Propagate) {
652659
NewWorkList.push_back(Inst);

0 commit comments

Comments
 (0)