Skip to content

Commit 0fdf912

Browse files
authored
[flang] Fold MATMUL() (#72176)
Implements constant folding for matrix multiplication for all four accepted type categories.
1 parent 2602d88 commit 0fdf912

File tree

7 files changed

+158
-6
lines changed

7 files changed

+158
-6
lines changed

flang/lib/Evaluate/fold-complex.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-matmul.h"
1011
#include "fold-reduction.h"
1112

1213
namespace Fortran::evaluate {
@@ -64,13 +65,14 @@ Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
6465
}
6566
} else if (name == "dot_product") {
6667
return FoldDotProduct<T>(context, std::move(funcRef));
68+
} else if (name == "matmul") {
69+
return FoldMatmul(context, std::move(funcRef));
6770
} else if (name == "product") {
6871
auto one{Scalar<Part>::FromInteger(value::Integer<8>{1}).value};
6972
return FoldProduct<T>(context, std::move(funcRef), Scalar<T>{one});
7073
} else if (name == "sum") {
7174
return FoldSum<T>(context, std::move(funcRef));
7275
}
73-
// TODO: matmul
7476
return Expr<T>{std::move(funcRef)};
7577
}
7678

flang/lib/Evaluate/fold-integer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-matmul.h"
1011
#include "fold-reduction.h"
1112
#include "flang/Evaluate/check-expression.h"
1213

@@ -1042,6 +1043,8 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
10421043
ScalarFunc<T, Int4>([&fptr](const Scalar<Int4> &places) -> Scalar<T> {
10431044
return fptr(static_cast<int>(places.ToInt64()));
10441045
}));
1046+
} else if (name == "matmul") {
1047+
return FoldMatmul(context, std::move(funcRef));
10451048
} else if (name == "max") {
10461049
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
10471050
} else if (name == "max0" || name == "max1") {
@@ -1279,7 +1282,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
12791282
} else if (name == "ubound") {
12801283
return UBOUND(context, std::move(funcRef));
12811284
}
1282-
// TODO: dot_product, matmul, sign
12831285
return Expr<T>{std::move(funcRef)};
12841286
}
12851287

flang/lib/Evaluate/fold-logical.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-matmul.h"
1011
#include "fold-reduction.h"
1112
#include "flang/Evaluate/check-expression.h"
1213
#include "flang/Runtime/magic-numbers.h"
@@ -231,6 +232,8 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
231232
if (auto *expr{UnwrapExpr<Expr<SomeLogical>>(args[0])}) {
232233
return Fold(context, ConvertToType<T>(std::move(*expr)));
233234
}
235+
} else if (name == "matmul") {
236+
return FoldMatmul(context, std::move(funcRef));
234237
} else if (name == "out_of_range") {
235238
if (Expr<SomeType> * cx{UnwrapExpr<Expr<SomeType>>(args[0])}) {
236239
auto restorer{context.messages().DiscardMessages()};
@@ -367,7 +370,6 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
367370
name == "__builtin_ieee_support_underflow_control") {
368371
return Expr<T>{true};
369372
}
370-
// TODO: logical, matmul, parity
371373
return Expr<T>{std::move(funcRef)};
372374
}
373375

flang/lib/Evaluate/fold-matmul.h

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//===-- lib/Evaluate/fold-matmul.h ----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_EVALUATE_FOLD_MATMUL_H_
10+
#define FORTRAN_EVALUATE_FOLD_MATMUL_H_
11+
12+
#include "fold-implementation.h"
13+
14+
namespace Fortran::evaluate {
15+
16+
template <typename T>
17+
static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
18+
using Element = typename Constant<T>::Element;
19+
auto args{funcRef.arguments()};
20+
CHECK(args.size() == 2);
21+
Folder<T> folder{context};
22+
Constant<T> *ma{folder.Folding(args[0])};
23+
Constant<T> *mb{folder.Folding(args[1])};
24+
if (!ma || !mb) {
25+
return Expr<T>{std::move(funcRef)};
26+
}
27+
CHECK(ma->Rank() >= 1 && ma->Rank() <= 2 && mb->Rank() >= 1 &&
28+
mb->Rank() <= 2 && (ma->Rank() == 2 || mb->Rank() == 2));
29+
ConstantSubscript commonExtent{ma->shape().back()};
30+
if (mb->shape().front() != commonExtent) {
31+
context.messages().Say(
32+
"Arguments to MATMUL have distinct extents %zd and %zd on their last and first dimensions"_err_en_US,
33+
commonExtent, mb->shape().front());
34+
return MakeInvalidIntrinsic(std::move(funcRef));
35+
}
36+
ConstantSubscript rows{ma->Rank() == 1 ? 1 : ma->shape()[0]};
37+
ConstantSubscript columns{mb->Rank() == 1 ? 1 : mb->shape()[1]};
38+
std::vector<Element> elements;
39+
elements.reserve(rows * columns);
40+
bool overflow{false};
41+
[[maybe_unused]] const auto &rounding{
42+
context.targetCharacteristics().roundingMode()};
43+
// result(j,k) = SUM(A(j,:) * B(:,k))
44+
for (ConstantSubscript ci{0}; ci < columns; ++ci) {
45+
for (ConstantSubscript ri{0}; ri < rows; ++ri) {
46+
ConstantSubscripts aAt{ma->lbounds()};
47+
if (ma->Rank() == 2) {
48+
aAt[0] += ri;
49+
}
50+
ConstantSubscripts bAt{mb->lbounds()};
51+
if (mb->Rank() == 2) {
52+
bAt[1] += ci;
53+
}
54+
Element sum{};
55+
[[maybe_unused]] Element correction{};
56+
for (ConstantSubscript j{0}; j < commonExtent; ++j) {
57+
Element aElt{ma->At(aAt)};
58+
Element bElt{mb->At(bAt)};
59+
if constexpr (T::category == TypeCategory::Real ||
60+
T::category == TypeCategory::Complex) {
61+
// Kahan summation
62+
auto product{aElt.Multiply(bElt, rounding)};
63+
overflow |= product.flags.test(RealFlag::Overflow);
64+
auto next{correction.Add(product.value, rounding)};
65+
overflow |= next.flags.test(RealFlag::Overflow);
66+
auto added{sum.Add(next.value, rounding)};
67+
overflow |= added.flags.test(RealFlag::Overflow);
68+
correction = added.value.Subtract(sum, rounding)
69+
.value.Subtract(next.value, rounding)
70+
.value;
71+
sum = std::move(added.value);
72+
} else if constexpr (T::category == TypeCategory::Integer) {
73+
auto product{aElt.MultiplySigned(bElt)};
74+
overflow |= product.SignedMultiplicationOverflowed();
75+
auto added{sum.AddSigned(product.lower)};
76+
overflow |= added.overflow;
77+
sum = std::move(added.value);
78+
} else {
79+
static_assert(T::category == TypeCategory::Logical);
80+
sum = sum.OR(aElt.AND(bElt));
81+
}
82+
++aAt.back();
83+
++bAt.front();
84+
}
85+
elements.push_back(sum);
86+
}
87+
}
88+
if (overflow) {
89+
context.messages().Say(
90+
"MATMUL of %s data overflowed during computation"_warn_en_US,
91+
T::AsFortran());
92+
}
93+
ConstantSubscripts shape;
94+
if (ma->Rank() == 2) {
95+
shape.push_back(rows);
96+
}
97+
if (mb->Rank() == 2) {
98+
shape.push_back(columns);
99+
}
100+
return Expr<T>{Constant<T>{std::move(elements), std::move(shape)}};
101+
}
102+
} // namespace Fortran::evaluate
103+
#endif // FORTRAN_EVALUATE_FOLD_MATMUL_H_

flang/lib/Evaluate/fold-real.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "fold-implementation.h"
10+
#include "fold-matmul.h"
1011
#include "fold-reduction.h"
1112

1213
namespace Fortran::evaluate {
@@ -269,6 +270,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
269270
}
270271
return result.value;
271272
}));
273+
} else if (name == "matmul") {
274+
return FoldMatmul(context, std::move(funcRef));
272275
} else if (name == "max") {
273276
return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater);
274277
} else if (name == "maxval") {
@@ -446,7 +449,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
446449
return result.value;
447450
}));
448451
}
449-
// TODO: matmul
450452
return Expr<T>{std::move(funcRef)};
451453
}
452454

flang/lib/Evaluate/fold-reduction.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static Expr<T> FoldDotProduct(
4343
Expr<T> products{Fold(
4444
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
4545
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46-
Element correction; // Use Kahan summation for greater precision.
46+
Element correction{}; // Use Kahan summation for greater precision.
4747
const auto &rounding{context.targetCharacteristics().roundingMode()};
4848
for (const Element &x : cProducts.values()) {
4949
auto next{correction.Add(x, rounding)};
@@ -80,7 +80,7 @@ static Expr<T> FoldDotProduct(
8080
Expr<T> products{
8181
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
8282
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
83-
Element correction; // Use Kahan summation for greater precision.
83+
Element correction{}; // Use Kahan summation for greater precision.
8484
const auto &rounding{context.targetCharacteristics().roundingMode()};
8585
for (const Element &x : cProducts.values()) {
8686
auto next{correction.Add(x, rounding)};

flang/test/Evaluate/fold-matmul.f90

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
! RUN: %python %S/test_folding.py %s %flang_fc1
2+
! Tests folding of MATMUL()
3+
module m
4+
integer, parameter :: ia(2,3) = reshape([1, 2, 2, 3, 3, 4], shape(ia))
5+
integer, parameter :: ib(3,2) = reshape([1, 2, 3, 2, 3, 4], shape(ib))
6+
integer, parameter :: ix(*) = [1, 2]
7+
integer, parameter :: iy(*) = [1, 2, 3]
8+
integer, parameter :: iab(*,*) = matmul(ia, ib)
9+
integer, parameter :: ixa(*) = matmul(ix, ia)
10+
integer, parameter :: iay(*) = matmul(ia, iy)
11+
logical, parameter :: test_iab = all([iab] == [14, 20, 20, 29])
12+
logical, parameter :: test_ixa = all(ixa == [5, 8, 11])
13+
logical, parameter :: test_iay = all(iay == [14, 20])
14+
15+
real, parameter :: ra(*,*) = ia
16+
real, parameter :: rb(*,*) = ib
17+
real, parameter :: rx(*) = ix
18+
real, parameter :: ry(*) = iy
19+
real, parameter :: rab(*,*) = matmul(ra, rb)
20+
real, parameter :: rxa(*) = matmul(rx, ra)
21+
real, parameter :: ray(*) = matmul(ra, ry)
22+
logical, parameter :: test_rab = all(rab == iab)
23+
logical, parameter :: test_rxa = all(rxa == ixa)
24+
logical, parameter :: test_ray = all(ray == iay)
25+
26+
complex, parameter :: za(*,*) = cmplx(ra, -1.)
27+
complex, parameter :: zb(*,*) = cmplx(rb, -1.)
28+
complex, parameter :: zx(*) = cmplx(rx, -1.)
29+
complex, parameter :: zy(*) = cmplx(ry, -1.)
30+
complex, parameter :: zab(*,*) = matmul(za, zb)
31+
complex, parameter :: zxa(*) = matmul(zx, za)
32+
complex, parameter :: zay(*) = matmul(za, zy)
33+
logical, parameter :: test_zab = all([zab] == [(11,-12),(17,-15),(17,-15),(26,-18)])
34+
logical, parameter :: test_zxa = all(zxa == [(3,-6),(6,-8),(9,-10)])
35+
logical, parameter :: test_zay = all(zay == [(11,-12),(17,-15)])
36+
37+
logical, parameter :: la(16, 4) = reshape([((iand(shiftr(j,k),1)/=0, j=0,15), k=0,3)], shape(la))
38+
logical, parameter :: lb(4, 16) = transpose(la)
39+
logical, parameter :: lab(16, 16) = matmul(la, lb)
40+
logical, parameter :: test_lab = all([lab] .eqv. [((iand(k,j)/=0, k=0,15), j=0,15)])
41+
end

0 commit comments

Comments
 (0)