Skip to content

Commit 2554941

Browse files
committed
[MLIR] Matrix: support matrix-vector multiplication
This just moves in the implementation from LinearTransform. Reviewed By: Groverkss, bondhugula Differential Revision: https://reviews.llvm.org/D118479
1 parent 0c3d22a commit 2554941

File tree

4 files changed

+38
-26
lines changed

4 files changed

+38
-26
lines changed

mlir/include/mlir/Analysis/Presburger/LinearTransform.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,16 @@ class LinearTransform {
3939

4040
// The given vector is interpreted as a row vector v. Post-multiply v with
4141
// this transform, say T, and return vT.
42-
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
42+
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
43+
return matrix.preMultiplyWithRow(rowVec);
44+
}
4345

4446
// The given vector is interpreted as a column vector v. Pre-multiply v with
4547
// this transform, say T, and return Tv.
4648
SmallVector<int64_t, 8>
47-
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
49+
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
50+
return matrix.postMultiplyWithColumn(colVec);
51+
}
4852

4953
private:
5054
Matrix matrix;

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,15 @@ class Matrix {
117117
/// Negate the specified column.
118118
void negateColumn(unsigned column);
119119

120+
/// The given vector is interpreted as a row vector v. Post-multiply v with
121+
/// this matrix, say M, and return vM.
122+
SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
123+
124+
/// The given vector is interpreted as a column vector v. Pre-multiply v with
125+
/// this matrix, say M, and return Mv.
126+
SmallVector<int64_t, 8>
127+
postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
128+
120129
/// Resize the matrix to the specified dimensions. If a dimension is smaller,
121130
/// the values are truncated; if it is bigger, the new values are initialized
122131
/// to zero.

mlir/lib/Analysis/Presburger/LinearTransform.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,6 @@ LinearTransform::makeTransformToColumnEchelon(Matrix m) {
111111
return {echelonCol, LinearTransform(std::move(resultMatrix))};
112112
}
113113

114-
SmallVector<int64_t, 8>
115-
LinearTransform::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
116-
assert(rowVec.size() == matrix.getNumRows() &&
117-
"row vector dimension should match transform output dimension");
118-
119-
SmallVector<int64_t, 8> result(matrix.getNumColumns(), 0);
120-
for (unsigned col = 0, e = matrix.getNumColumns(); col < e; ++col)
121-
for (unsigned i = 0, e = matrix.getNumRows(); i < e; ++i)
122-
result[col] += rowVec[i] * matrix(i, col);
123-
return result;
124-
}
125-
126-
SmallVector<int64_t, 8>
127-
LinearTransform::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
128-
assert(matrix.getNumColumns() == colVec.size() &&
129-
"column vector dimension should match transform input dimension");
130-
131-
SmallVector<int64_t, 8> result(matrix.getNumRows(), 0);
132-
for (unsigned row = 0, e = matrix.getNumRows(); row < e; row++)
133-
for (unsigned i = 0, e = matrix.getNumColumns(); i < e; i++)
134-
result[row] += matrix(row, i) * colVec[i];
135-
return result;
136-
}
137-
138114
IntegerPolyhedron
139115
LinearTransform::applyTo(const IntegerPolyhedron &poly) const {
140116
IntegerPolyhedron result(poly.getNumIds());

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,29 @@ void Matrix::negateColumn(unsigned column) {
203203
at(row, column) = -at(row, column);
204204
}
205205

206+
SmallVector<int64_t, 8>
207+
Matrix::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
208+
assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
209+
210+
SmallVector<int64_t, 8> result(getNumColumns(), 0);
211+
for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
212+
for (unsigned i = 0, e = getNumRows(); i < e; ++i)
213+
result[col] += rowVec[i] * at(i, col);
214+
return result;
215+
}
216+
217+
SmallVector<int64_t, 8>
218+
Matrix::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
219+
assert(getNumColumns() == colVec.size() &&
220+
"Invalid column vector dimension!");
221+
222+
SmallVector<int64_t, 8> result(getNumRows(), 0);
223+
for (unsigned row = 0, e = getNumRows(); row < e; row++)
224+
for (unsigned i = 0, e = getNumColumns(); i < e; i++)
225+
result[row] += at(row, i) * colVec[i];
226+
return result;
227+
}
228+
206229
void Matrix::print(raw_ostream &os) const {
207230
for (unsigned row = 0; row < nRows; ++row) {
208231
for (unsigned column = 0; column < nColumns; ++column)

0 commit comments

Comments
 (0)