Commit 2d32ee0

Author: Nicolas Vasilache (authored and committed)
[mlir][Vector] Update lowering of vector ops to llvm intrinsics to use row-major.
Summary:
LLVM's matrix intrinsics recently introduced an option supporting row-major mode, which matches the MLIR vector model; this revision switches the lowering to row-major.

A corner case related to degenerate sizes was also fixed upstream, so this revision removes the guard against that corner case. In addition, a bug uncovered in the construction of the output vector is fixed here.

This has been tested independently on small matrices and benchmarked: no visible performance regression is observed.

In the future, when the matrix intrinsics support per-op attributes, we can translate to them more aggressively and avoid inserting MLIR-level transposes.

Differential Revision: https://reviews.llvm.org/D77761
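As background for the output-vector fix (an illustrative note, not part of the original commit message): a row-major matmul of an M x K LHS with a K x N RHS produces an M x N result, stored flattened row-major, so

    C = A B \in \mathbb{R}^{M \times N}, \qquad \mathrm{flat}(C)_{\,i \cdot N + j} = C_{ij}, \qquad |\mathrm{flat}(C)| = M \cdot N

i.e. the flattened result vector must hold lhsRows * rhsColumns elements (2 * 3 = 6 for the 2x4 by 4x3 case in the updated test), not lhsRows * lhsColumns (which would give 2 * 4 = 8).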
1 parent 00a1032 commit 2d32ee0

File tree

3 files changed (+53 / -73 lines)

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 1 addition & 1 deletion
@@ -1446,7 +1446,7 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
       result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
       result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
       result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
-      result.addTypes(VectorType::get(lhsRows * lhsColumns,
+      result.addTypes(VectorType::get(lhsRows * rhsColumns,
           lhs.getType().cast<VectorType>().getElementType()));
     }]>,
   ];
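For illustration only, a sketch (not part of the diff) of the corrected op with the shapes used in the updated test below; %a and %b are hypothetical flattened row-major operands:

// %a: 2x4 LHS flattened to vector<8xf32>, %b: 4x3 RHS flattened to vector<12xf32> (hypothetical values).
%c = vector.matrix_multiply %a, %b
       {lhs_rows = 2 : i32, lhs_columns = 4 : i32, rhs_columns = 3 : i32}
     : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
// The result has lhsRows * rhsColumns = 6 elements; before the fix the builder requested lhsRows * lhsColumns = 8.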

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 21 additions & 30 deletions
@@ -1125,43 +1125,34 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 
     // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
     // a new pattern.
-    // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
-    // intrinsics, use that.
     if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
-        isColumnMajorMatmul(op.indexing_maps())) {
+        isRowMajorMatmul(op.indexing_maps())) {
       VectorType lhsType = op.getLhsType();
       VectorType rhsType = op.getRhsType();
       unsigned lhsRows = op.getLhsType().getShape()[0];
       unsigned lhsColumns = op.getLhsType().getShape()[1];
       unsigned rhsColumns = op.getRhsType().getShape()[1];
 
-      // In cases where matrices are degenerate, scalarization issues occur in
-      // the backend. Avoid all LLVM scalarization issues for now.
-      // For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
-      // https://bugs.llvm.org/show_bug.cgi?id=45229
-      // TODO(ntv, fhahn): Relax once above bugs are fixed.
-      if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
-        Type flattenedLHSType =
-            VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-        Type flattenedRHSType =
-            VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-        auto lhs = rewriter.create<vector::ShapeCastOp>(
-            op.getLoc(), flattenedLHSType, op.lhs());
-        auto rhs = rewriter.create<vector::ShapeCastOp>(
-            op.getLoc(), flattenedRHSType, op.rhs());
-
-        Value mul = rewriter.create<vector::MatmulOp>(
-            op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-        mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
-                                                   op.acc().getType(), mul);
-        Type elementType = op.getLhsType().getElementType();
-        assert(elementType.isIntOrFloat());
-        if (elementType.isa<IntegerType>())
-          rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-        else
-          rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-        return success();
-      }
+      Type flattenedLHSType =
+          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+      Type flattenedRHSType =
+          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+      auto lhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedLHSType, op.lhs());
+      auto rhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedRHSType, op.rhs());
+
+      Value mul = rewriter.create<vector::MatmulOp>(
+          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
+                                                 op.acc().getType(), mul);
+      Type elementType = op.getLhsType().getElementType();
+      assert(elementType.isIntOrFloat());
+      if (elementType.isa<IntegerType>())
+        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+      else
+        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+      return success();
     }
 
     // Find first batch dimension in LHS/RHS, and lower when found.

mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Lines changed: 31 additions & 42 deletions
@@ -357,46 +357,35 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
-// MATRIX-LABEL: func @column_major_matmul
-// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
-// MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32>
-// MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32>
-// MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
-// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32>
-// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32>
-// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32>
-// MATRIX: %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32>
-// MATRIX: %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-// MATRIX: %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32>
-// MATRIX: %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
-// MATRIX: %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32>
-// MATRIX: %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
-// MATRIX: %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32>
-// MATRIX: %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-// MATRIX: %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32>
-// MATRIX: %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32>
-#column_major_matmat_accesses = [
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (j, i)>
-]
-#column_major_matmat_trait = {
-  indexing_maps = #column_major_matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-func @column_major_matmul(%arg0: vector<4x3xf32>,
-                          %arg1: vector<2x4xf32>,
-                          %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
-  %0 = vector.contract #column_major_matmat_trait %arg0, %arg1, %arg2
-    : vector<4x3xf32>, vector<2x4xf32> into vector<3x2xf32>
-  return %0 : vector<3x2xf32>
+// MATRIX-LABEL: func @matmul
+// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// MATRIX: %[[vcst:.*]] = constant dense<0.000000e+00> : vector<8xf32>
+// MATRIX: %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<12xf32>
+// MATRIX: %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
+// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
+// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
+// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
+// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
+// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
+// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
+// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
+// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+// MATRIX: %[[mm2:.*]] = vector.strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
+// MATRIX: %[[mm4:.*]] = vector.strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
+// MATRIX: %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32>
+func @matmul(%arg0: vector<2x4xf32>,
+             %arg1: vector<4x3xf32>,
+             %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
 }
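The new test references #matmat_trait, which is defined earlier in the test file and is not shown in this hunk. For reference, a sketch of the row-major matmul trait that isRowMajorMatmul is expected to match (assumed here, C(i, j) += A(i, k) * B(k, j); consult the test file for the exact definition):

// Assumed row-major accesses: A indexed by (i, k), B by (k, j), C by (i, j).
#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}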
