Skip to content

Commit 1707b84

Browse files
committed
[flang] Avoid optimizing min and max if not valid type
In `makeMinMaxInitValGenerator` it explicitly checks for only `FloatType` and `IntegerType`, so we shouldn't match if we don't have either of those types.
1 parent 19e0233 commit 1707b84

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,8 +988,15 @@ class ReductionConversion : public mlir::OpRewritePattern<Op> {
988988
op, "Currently minloc/maxloc is not handled");
989989
} else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> ||
990990
std::is_same_v<Op, hlfir::MinvalOp>) {
991+
mlir::Type ty = op.getType();
992+
if (!(mlir::isa<mlir::FloatType>(ty) ||
993+
mlir::isa<mlir::IntegerType>(ty))) {
994+
return rewriter.notifyMatchFailure(
995+
op, "Type is not supported for Maxval or Minval yet");
996+
}
997+
991998
bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>;
992-
init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType());
999+
init = makeMinMaxInitValGenerator(isMax)(builder, loc, ty);
9931000
genBodyFn = [inlineSource, isMax](
9941001
fir::FirOpBuilder builder, mlir::Location loc,
9951002
mlir::Value reduction,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Test maxval inlining for both elemental and designate
2+
// RUN: fir-opt %s -opt-bufferization | FileCheck %s
3+
4+
func.func @_QQmain() {
5+
%c8 = arith.constant 8 : index
6+
%1 = fir.alloca !fir.char<1, 8>
7+
%24 = fir.shape %c8 : (index) -> !fir.shape<1>
8+
%25 = hlfir.elemental %24 typeparams %c8 unordered : (!fir.shape<1>, index) -> !hlfir.expr<?x!fir.char<1, 8>> {
9+
^bb0(%arg0: index):
10+
%dummy = fir.string_lit "A"(8) : !fir.char<1, 8>
11+
hlfir.yield_element %dummy : !fir.char<1, 8>
12+
}
13+
%26 = hlfir.maxval %25 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<?x!fir.char<1, 8>>) -> !hlfir.expr<!fir.char<1, 8>>
14+
hlfir.assign %26 to %1 : !hlfir.expr<!fir.char<1, 8>>, !fir.ref<!fir.char<1, 8>> // Assign to %1 directly
15+
hlfir.destroy %26 : !hlfir.expr<!fir.char<1, 8>>
16+
hlfir.destroy %25 : !hlfir.expr<?x!fir.char<1, 8>>
17+
return
18+
}
19+
20+
// CHECK-LABEL: func.func @_QQmain() {
21+
// CHECK-NEXT: %c8 = arith.constant 8 : index
22+
// CHECK-NEXT: %[[V0:.*]] = fir.alloca !fir.char<1,8>
23+
// CHECK-NEXT: %[[V1:.*]] = fir.shape %c8 : (index) -> !fir.shape<1>
24+
// CHECK-NEXT: %[[V2:.*]] = hlfir.elemental %1 typeparams %c8 unordered : (!fir.shape<1>, index) -> !hlfir.expr<?x!fir.char<1,8>> {
25+
// CHECK-NEXT: ^bb0(%arg0: index):
26+
// CHECK-NEXT: %[[V4:.*]] = fir.string_lit "A"(8) : !fir.char<1,8>
27+
// CHECK-NEXT: hlfir.yield_element %[[V4]] : !fir.char<1,8>
28+
// CHECK-NEXT: }
29+
// CHECK-NEXT: %[[V3:.*]] = hlfir.maxval %[[V2]] {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<?x!fir.char<1,8>>) -> !hlfir.expr<!fir.char<1,8>>
30+
// CHECK-NEXT: hlfir.assign %[[V3]] to %[[V0]] : !hlfir.expr<!fir.char<1,8>>, !fir.ref<!fir.char<1,8>>
31+
// CHECK-NEXT: hlfir.destroy %[[V3]] : !hlfir.expr<!fir.char<1,8>>
32+
// CHECK-NEXT: hlfir.destroy %[[V2]] : !hlfir.expr<?x!fir.char<1,8>>
33+
// CHECK-NEXT: return
34+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)