Skip to content

Commit e55aa02

Browse files
authored
[flang] Fix runtime error messages for the MATMUL intrinsic (#96928)
There are three forms of MATMUL -- where the first argument is a rank 1 array, where the second argument is a rank 1 array, and where both arguments are rank 2 arrays. There's code in the runtime that detects when the array shapes are incorrect. But the code that emits an error message assumes that both arguments are rank 2 arrays. This change contains code for the other two cases.
1 parent 5b36348 commit e55aa02

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

flang/runtime/matmul.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,11 +288,25 @@ static inline RT_API_ATTRS void DoMatmul(
288288
}
289289
SubscriptValue n{x.GetDimension(xRank - 1).Extent()};
290290
if (n != y.GetDimension(0).Extent()) {
291-
terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
292-
static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
293-
static_cast<std::intmax_t>(n),
294-
static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
295-
static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
291+
// At this point, we know that there's a shape error. There are three
292+
// possibilities, x is rank 1, y is rank 1, or both are rank 2.
293+
if (xRank == 1) {
294+
terminator.Crash("MATMUL: unacceptable operand shapes (%jd, %jdx%jd)",
295+
static_cast<std::intmax_t>(n),
296+
static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
297+
static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
298+
} else if (yRank == 1) {
299+
terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jd)",
300+
static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
301+
static_cast<std::intmax_t>(n),
302+
static_cast<std::intmax_t>(y.GetDimension(0).Extent()));
303+
} else {
304+
terminator.Crash("MATMUL: unacceptable operand shapes (%jdx%jd, %jdx%jd)",
305+
static_cast<std::intmax_t>(x.GetDimension(0).Extent()),
306+
static_cast<std::intmax_t>(n),
307+
static_cast<std::intmax_t>(y.GetDimension(0).Extent()),
308+
static_cast<std::intmax_t>(y.GetDimension(1).Extent()));
309+
}
296310
}
297311
using WriteResult =
298312
CppTypeFor<RCAT == TypeCategory::Logical ? TypeCategory::Integer : RCAT,

0 commit comments

Comments
 (0)