Skip to content

Commit 0e8717f

Browse files
committed
[Matrix] Add shape verification.
At the moment, lower-matrix-intrinsics accepts mis-matches between shapes for operations. See shape-verification.ll for an example where @llvm.matrix.column.major.load specifies 6x1 and then the use (@llvm.matrix.multiply) specifies the operand to have 1x6. This patch adds verification for shapes to check if shapes match. Reviewed By: thegameg Differential Revision: https://reviews.llvm.org/D147438
1 parent c01ea05 commit 0e8717f

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled(
7272
cl::desc("Allow the use of FMAs if available and profitable. This may "
7373
"result in different results, due to less rounding error."));
7474

75+
static cl::opt<bool>
76+
VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
77+
cl::desc("Enable/disable matrix shape verification."),
78+
cl::init(false));
79+
7580
enum class MatrixLayoutTy { ColumnMajor, RowMajor };
7681

7782
static cl::opt<MatrixLayoutTy> MatrixLayout(
@@ -535,6 +540,15 @@ class LowerMatrixIntrinsics {
535540

536541
auto SIter = ShapeMap.find(V);
537542
if (SIter != ShapeMap.end()) {
543+
if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
544+
SIter->second.NumColumns != Shape.NumColumns)) {
545+
errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
546+
<< SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
547+
<< Shape.NumColumns << ") for " << *V << "\n";
548+
report_fatal_error(
549+
"Matrix shape verification failed, compilation aborted!");
550+
}
551+
538552
LLVM_DEBUG(dbgs() << " not overriding existing shape: "
539553
<< SIter->second.NumRows << " "
540554
<< SIter->second.NumColumns << " for " << *V << "\n");
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not --crash opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=true -S %s 2>&1 | FileCheck --check-prefix=VERIFY %s
2+
; RUN: opt -passes='lower-matrix-intrinsics' -verify-matrix-shapes=false -S %s 2>&1 | FileCheck --check-prefix=NOVERIFY %s
3+
4+
; VERIFY: Conflicting shapes (6x1 vs 1x6)
5+
; NOVERIFY-NOT: Conflicting shapes
6+
7+
define <1 x float> @intrinsic_column_major_load_dot_product_float_v6(ptr %lhs_address, ptr %rhs_address) {
8+
entry:
9+
%lhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %lhs_address, i64 6, i1 false, i32 6, i32 1)
10+
%rhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 6)
11+
%result = tail call fast <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float> %lhs, <6 x float> %rhs, i32 1, i32 6, i32 1)
12+
ret <1 x float> %result
13+
}
14+
15+
declare <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4, i64, i1, i32, i32)
16+
declare <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float>, <6 x float>, i32, i32, i32)

0 commit comments

Comments
 (0)