Skip to content

Pick fix for matrix transpose hoisting. #8175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 118 additions & 103 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
}

namespace {
/// Shape (NumRows x NumColumns) of a matrix value, together with the layout
/// used when lowering it to a set of flat vectors.
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  // True when the matrix is lowered as column vectors
  // (MatrixLayout == ColumnMajor), false when lowered as row vectors.
  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  /// Construct from the constant-int dimension operands of a matrix
  /// intrinsic call.
  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  // const-qualified so comparisons work through const references; only the
  // dimensions participate, since the layout is a pass-wide global setting.
  bool operator==(const ShapeInfo &other) const {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) const { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  /// Number of elements per lowered vector, i.e. the stride between the
  /// starts of consecutive vectors in the flattened matrix.
  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  /// Number of vectors the matrix is split into under the current layout.
  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
} // namespace

/// Returns true when \p V either is not an instruction at all, or is one of
/// the element-wise operations whose result shape simply mirrors the shape
/// of its matrix operands.
static bool isUniformShape(Value *V) {
  auto *Inst = dyn_cast<Instruction>(V);
  // Non-instruction values (arguments, constants) place no constraint on
  // shape propagation.
  if (!Inst)
    return true;

  const unsigned Opcode = Inst->getOpcode();
  return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
         Opcode == Instruction::FMul /* scalar multiply */ ||
         Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
         Opcode == Instruction::Mul || Opcode == Instruction::Sub;
}

/// Return the ShapeInfo for the result of \p I, if it can be determined.
///
/// \p ShapeMap supplies already-known shapes of other values; it is consulted
/// for plain stores and for uniform-shape (element-wise) instructions.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  // multiply with dimension operands (M, N, K) produces an M x K result.
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N)))) {
    // Flip dimensions.
    return ShapeInfo(N, M);
  }
  // NOTE(review): the recorded shape for the store is {N, M}, with the
  // dimension operands swapped — this matches the pre-existing propagation
  // code this helper was factored out of.
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  // A plain store has no shape of its own; reuse the stored operand's shape
  // when it is already known.
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
Expand Down Expand Up @@ -390,48 +493,6 @@ class LowerMatrixIntrinsics {
}
};

/// Shape (NumRows x NumColumns) of a matrix value, together with the layout
/// used when lowering it to a set of flat vectors.
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  // True when the matrix is lowered as column vectors
  // (MatrixLayout == ColumnMajor), false when lowered as row vectors.
  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  /// Construct from the constant-int dimension operands of a matrix
  /// intrinsic call.
  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  // Only the dimensions participate in equality; the layout is a pass-wide
  // global setting.
  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  /// Number of elements per lowered vector, i.e. the stride between the
  /// starts of consecutive vectors in the flattened matrix.
  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  /// Number of vectors the matrix is split into under the current layout.
  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};

/// Maps instructions to their shape information. The shape information
/// describes the shape to be used while lowering. This matches the shape of
/// the result value of the instruction, with the only exceptions being store
Expand Down Expand Up @@ -561,25 +622,6 @@ class LowerMatrixIntrinsics {
return true;
}

/// Returns true when \p V either is not an instruction at all, or is one of
/// the element-wise operations whose result shape simply mirrors the shape
/// of its matrix operands.
bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  // Non-instruction values (arguments, constants) place no constraint on
  // shape propagation.
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}

/// Returns true if shape information can be used for \p V. The supported
/// instructions must match the instructions that can be lowered by this pass.
bool supportsShapeInfo(Value *V) {
Expand Down Expand Up @@ -617,43 +659,8 @@ class LowerMatrixIntrinsics {

// New entry, set the value and insert operands
bool Propagate = false;

Value *MatrixA;
Value *MatrixB;
Value *M;
Value *N;
Value *K;
if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
m_Value(N), m_Value(K)))) {
Propagate = setShapeInfo(Inst, {M, K});
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
// Flip dimensions.
Propagate = setShapeInfo(Inst, {N, M});
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
m_Value(MatrixA), m_Value(), m_Value(),
m_Value(), m_Value(M), m_Value(N)))) {
Propagate = setShapeInfo(Inst, {N, M});
} else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
m_Value(), m_Value(), m_Value(), m_Value(M),
m_Value(N)))) {
Propagate = setShapeInfo(Inst, {M, N});
} else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
auto OpShape = ShapeMap.find(MatrixA);
if (OpShape != ShapeMap.end())
setShapeInfo(Inst, OpShape->second);
continue;
} else if (isUniformShape(Inst)) {
// Find the first operand that has a known shape and use that.
for (auto &Op : Inst->operands()) {
auto OpShape = ShapeMap.find(Op.get());
if (OpShape != ShapeMap.end()) {
Propagate |= setShapeInfo(Inst, OpShape->second);
break;
}
}
}
if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
Propagate = setShapeInfo(Inst, *SI);

if (Propagate) {
NewWorkList.push_back(Inst);
Expand Down Expand Up @@ -898,20 +905,28 @@ class LowerMatrixIntrinsics {
updateShapeAndReplaceAllUsesWith(I, NewInst);
CleanupBinOp(I, A, B);
}
// A^t + B^t -> (A + B)^t
// A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose. If
// the shape of the second transpose is different, there's a shape conflict
// which gets resolved by picking the shape of the first operand.
else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) {
m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
IRBuilder<> Builder(&I);
Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
setShapeInfo(Add, {C, R});
auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
setShapeInfo(Add, {R, C});
MatrixBuilder MBuilder(Builder);
Instruction *NewInst = MBuilder.CreateMatrixTranspose(
Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t");
Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
updateShapeAndReplaceAllUsesWith(I, NewInst);
assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
computeShapeInfoForInst(&I, ShapeMap) &&
"Shape of new instruction doesn't match original shape.");
CleanupBinOp(I, A, B);
assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
ShapeMap[Add] &&
"Shape of updated addition doesn't match cached shape.");
}
}

Expand Down
Loading