19
19
#include " flang/Optimizer/HLFIR/HLFIROps.h"
20
20
#include " flang/Optimizer/HLFIR/Passes.h"
21
21
#include " mlir/Dialect/Arith/IR/Arith.h"
22
- #include " mlir/Dialect/Func/IR/FuncOps.h"
23
- #include " mlir/IR/BuiltinDialect.h"
24
22
#include " mlir/IR/Location.h"
25
23
#include " mlir/Pass/Pass.h"
26
- #include " mlir/Transforms/DialectConversion .h"
24
+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
27
25
28
26
namespace hlfir {
29
27
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -45,9 +43,15 @@ class TransposeAsElementalConversion
45
43
llvm::LogicalResult
46
44
matchAndRewrite (hlfir::TransposeOp transpose,
47
45
mlir::PatternRewriter &rewriter) const override {
46
+ hlfir::ExprType expr = transpose.getType ();
47
+ // TODO: hlfir.elemental supports polymorphic data types now,
48
+ // so this can be supported.
49
+ if (expr.isPolymorphic ())
50
+ return rewriter.notifyMatchFailure (transpose,
51
+ " TRANSPOSE of polymorphic type" );
52
+
48
53
mlir::Location loc = transpose.getLoc ();
49
54
fir::FirOpBuilder builder{rewriter, transpose.getOperation ()};
50
- hlfir::ExprType expr = transpose.getType ();
51
55
mlir::Type elementType = expr.getElementType ();
52
56
hlfir::Entity array = hlfir::Entity{transpose.getArray ()};
53
57
mlir::Value resultShape = genResultShape (loc, builder, array);
@@ -105,15 +109,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
105
109
llvm::LogicalResult
106
110
matchAndRewrite (hlfir::SumOp sum,
107
111
mlir::PatternRewriter &rewriter) const override {
112
+ if (!simplifySum)
113
+ return rewriter.notifyMatchFailure (sum, " SUM simplification is disabled" );
114
+
115
+ hlfir::Entity array = hlfir::Entity{sum.getArray ()};
116
+ bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
117
+ mlir::Value dim = sum.getDim ();
118
+ int64_t dimVal = 0 ;
119
+ if (!isTotalReduction) {
120
+ // In case of partial reduction we should ignore the operations
121
+ // with invalid DIM values. They may appear in dead code
122
+ // after constant propagation.
123
+ auto constDim = fir::getIntIfConstant (dim);
124
+ if (!constDim)
125
+ return rewriter.notifyMatchFailure (sum, " Nonconstant DIM for SUM" );
126
+ dimVal = *constDim;
127
+
128
+ if ((dimVal <= 0 || dimVal > array.getRank ()))
129
+ return rewriter.notifyMatchFailure (
130
+ sum, " Invalid DIM for partial SUM reduction" );
131
+ }
132
+
108
133
mlir::Location loc = sum.getLoc ();
109
134
fir::FirOpBuilder builder{rewriter, sum.getOperation ()};
110
135
mlir::Type elementType = hlfir::getFortranElementType (sum.getType ());
111
- hlfir::Entity array = hlfir::Entity{sum.getArray ()};
112
136
mlir::Value mask = sum.getMask ();
113
- mlir::Value dim = sum.getDim ();
114
- bool isTotalReduction = hlfir::Entity{sum}.getRank () == 0 ;
115
- int64_t dimVal =
116
- isTotalReduction ? 0 : fir::getIntIfConstant (dim).value_or (0 );
137
+
117
138
mlir::Value resultShape, dimExtent;
118
139
llvm::SmallVector<mlir::Value> arrayExtents;
119
140
if (isTotalReduction)
@@ -360,27 +381,38 @@ class CShiftAsElementalConversion
360
381
public:
361
382
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
362
383
363
- explicit CShiftAsElementalConversion (mlir::MLIRContext *ctx)
364
- : OpRewritePattern(ctx) {
365
- setHasBoundedRewriteRecursion ();
366
- }
367
-
368
384
llvm::LogicalResult
369
385
matchAndRewrite (hlfir::CShiftOp cshift,
370
386
mlir::PatternRewriter &rewriter) const override {
371
387
using Fortran::common::maxRank;
372
388
373
- mlir::Location loc = cshift.getLoc ();
374
- fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
375
389
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType ());
376
390
assert (expr &&
377
391
" expected an expression type for the result of hlfir.cshift" );
392
+ unsigned arrayRank = expr.getRank ();
393
+ // When it is a 1D CSHIFT, we may assume that the DIM argument
394
+ // (whether it is present or absent) is equal to 1, otherwise,
395
+ // the program is illegal.
396
+ int64_t dimVal = 1 ;
397
+ if (arrayRank != 1 )
398
+ if (mlir::Value dim = cshift.getDim ()) {
399
+ auto constDim = fir::getIntIfConstant (dim);
400
+ if (!constDim)
401
+ return rewriter.notifyMatchFailure (cshift,
402
+ " Nonconstant DIM for CSHIFT" );
403
+ dimVal = *constDim;
404
+ }
405
+
406
+ if (dimVal <= 0 || dimVal > arrayRank)
407
+ return rewriter.notifyMatchFailure (cshift, " Invalid DIM for CSHIFT" );
408
+
409
+ mlir::Location loc = cshift.getLoc ();
410
+ fir::FirOpBuilder builder{rewriter, cshift.getOperation ()};
378
411
mlir::Type elementType = expr.getElementType ();
379
412
hlfir::Entity array = hlfir::Entity{cshift.getArray ()};
380
413
mlir::Value arrayShape = hlfir::genShape (loc, builder, array);
381
414
llvm::SmallVector<mlir::Value> arrayExtents =
382
415
hlfir::getExplicitExtentsFromShape (arrayShape, builder);
383
- unsigned arrayRank = expr.getRank ();
384
416
llvm::SmallVector<mlir::Value, 1 > typeParams;
385
417
hlfir::genLengthParameters (loc, builder, array, typeParams);
386
418
hlfir::Entity shift = hlfir::Entity{cshift.getShift ()};
@@ -395,20 +427,6 @@ class CShiftAsElementalConversion
395
427
shiftVal = builder.createConvert (loc, calcType, shiftVal);
396
428
}
397
429
398
- int64_t dimVal = 1 ;
399
- if (arrayRank == 1 ) {
400
- // When it is a 1D CSHIFT, we may assume that the DIM argument
401
- // (whether it is present or absent) is equal to 1, otherwise,
402
- // the program is illegal.
403
- assert (shiftVal && " SHIFT must be scalar" );
404
- } else {
405
- if (mlir::Value dim = cshift.getDim ())
406
- dimVal = fir::getIntIfConstant (dim).value_or (0 );
407
- assert (dimVal > 0 && dimVal <= arrayRank &&
408
- " DIM must be present and a positive constant not exceeding "
409
- " the array's rank" );
410
- }
411
-
412
430
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
413
431
mlir::ValueRange inputIndices) -> hlfir::Entity {
414
432
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -462,68 +480,19 @@ class SimplifyHLFIRIntrinsics
462
480
public:
463
481
void runOnOperation () override {
464
482
mlir::MLIRContext *context = &getContext ();
483
+
484
+ mlir::GreedyRewriteConfig config;
485
+ // Prevent the pattern driver from merging blocks
486
+ config.enableRegionSimplification =
487
+ mlir::GreedySimplifyRegionLevel::Disabled;
488
+
465
489
mlir::RewritePatternSet patterns (context);
466
490
patterns.insert <TransposeAsElementalConversion>(context);
467
491
patterns.insert <SumAsElementalConversion>(context);
468
492
patterns.insert <CShiftAsElementalConversion>(context);
469
- mlir::ConversionTarget target (*context);
470
- // don't transform transpose of polymorphic arrays (not currently supported
471
- // by hlfir.elemental)
472
- target.addDynamicallyLegalOp <hlfir::TransposeOp>(
473
- [](hlfir::TransposeOp transpose) {
474
- return mlir::cast<hlfir::ExprType>(transpose.getType ())
475
- .isPolymorphic ();
476
- });
477
- // Handle only SUM(DIM=CONSTANT) case for now.
478
- // It may be beneficial to expand the non-DIM case as well.
479
- // E.g. when the input array is an elemental array expression,
480
- // expanding the SUM into a total reduction loop nest
481
- // would avoid creating a temporary for the elemental array expression.
482
- target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
483
- if (!simplifySum)
484
- return true ;
485
-
486
- // Always inline total reductions.
487
- if (hlfir::Entity{sum}.getRank () == 0 )
488
- return false ;
489
- mlir::Value dim = sum.getDim ();
490
- if (!dim)
491
- return false ;
492
-
493
- if (auto dimVal = fir::getIntIfConstant (dim)) {
494
- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
495
- hlfir::getFortranElementOrSequenceType (sum.getArray ().getType ()));
496
- if (*dimVal > 0 && *dimVal <= arrayTy.getDimension ()) {
497
- // Ignore SUMs with illegal DIM values.
498
- // They may appear in dead code,
499
- // and they do not have to be converted.
500
- return false ;
501
- }
502
- }
503
- return true ;
504
- });
505
- target.addDynamicallyLegalOp <hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
506
- unsigned resultRank = hlfir::Entity{cshift}.getRank ();
507
- if (resultRank == 1 )
508
- return false ;
509
-
510
- mlir::Value dim = cshift.getDim ();
511
- if (!dim)
512
- return false ;
513
-
514
- // If DIM is present, then it must be constant to please
515
- // the conversion. In addition, ignore cases with
516
- // illegal DIM values.
517
- if (auto dimVal = fir::getIntIfConstant (dim))
518
- if (*dimVal > 0 && *dimVal <= resultRank)
519
- return false ;
520
-
521
- return true ;
522
- });
523
- target.markUnknownOpDynamicallyLegal (
524
- [](mlir::Operation *) { return true ; });
525
- if (mlir::failed (mlir::applyFullConversion (getOperation (), target,
526
- std::move (patterns)))) {
493
+
494
+ if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
495
+ getOperation (), std::move (patterns), config))) {
527
496
mlir::emitError (getOperation ()->getLoc (),
528
497
" failure in HLFIR intrinsic simplification" );
529
498
signalPassFailure ();
0 commit comments