Skip to content

[flang] Turn SimplifyHLFIRIntrinsics into a greedy rewriter. #119946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024

Conversation

vzakhari
Copy link
Contributor

This is almost an NFC, except that folding changed ordering
of some operations.

This is almost an NFC, except that folding changed ordering
of some operations.
@vzakhari vzakhari requested review from tblah and jeanPerier December 14, 2024 03:22
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 14, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

This is almost an NFC, except that folding changed ordering
of some operations.


Patch is 80.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119946.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+58-89)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-cshift.fir (+58-84)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+100-126)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics.fir (+26-32)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 3e9d956b6e56dd..3bccf25865c727 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -19,11 +19,9 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace hlfir {
 #define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -44,9 +42,15 @@ class TransposeAsElementalConversion
   llvm::LogicalResult
   matchAndRewrite(hlfir::TransposeOp transpose,
                   mlir::PatternRewriter &rewriter) const override {
+    hlfir::ExprType expr = transpose.getType();
+    // TODO: hlfir.elemental supports polymorphic data types now,
+    // so this can be supported.
+    if (expr.isPolymorphic())
+      return rewriter.notifyMatchFailure(transpose,
+                                         "TRANSPOSE of polymorphic type");
+
     mlir::Location loc = transpose.getLoc();
     fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
-    hlfir::ExprType expr = transpose.getType();
     mlir::Type elementType = expr.getElementType();
     hlfir::Entity array = hlfir::Entity{transpose.getArray()};
     mlir::Value resultShape = genResultShape(loc, builder, array);
@@ -104,15 +108,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
   llvm::LogicalResult
   matchAndRewrite(hlfir::SumOp sum,
                   mlir::PatternRewriter &rewriter) const override {
+    if (!simplifySum)
+      return rewriter.notifyMatchFailure(sum, "SUM simplification is disabled");
+
+    hlfir::Entity array = hlfir::Entity{sum.getArray()};
+    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+    mlir::Value dim = sum.getDim();
+    int64_t dimVal = 0;
+    if (!isTotalReduction) {
+      // In case of partial reduction we should ignore the operations
+      // with invalid DIM values. They may appear in dead code
+      // after constant propagation.
+      auto constDim = fir::getIntIfConstant(dim);
+      if (!constDim)
+        return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
+      dimVal = *constDim;
+
+      if ((dimVal <= 0 || dimVal > array.getRank()))
+        return rewriter.notifyMatchFailure(
+            sum, "Invalid DIM for partial SUM reduction");
+    }
+
     mlir::Location loc = sum.getLoc();
     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
     mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
-    hlfir::Entity array = hlfir::Entity{sum.getArray()};
     mlir::Value mask = sum.getMask();
-    mlir::Value dim = sum.getDim();
-    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
-    int64_t dimVal =
-        isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
+
     mlir::Value resultShape, dimExtent;
     llvm::SmallVector<mlir::Value> arrayExtents;
     if (isTotalReduction)
@@ -359,27 +380,38 @@ class CShiftAsElementalConversion
 public:
   using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
 
-  explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
-      : OpRewritePattern(ctx) {
-    setHasBoundedRewriteRecursion();
-  }
-
   llvm::LogicalResult
   matchAndRewrite(hlfir::CShiftOp cshift,
                   mlir::PatternRewriter &rewriter) const override {
     using Fortran::common::maxRank;
 
-    mlir::Location loc = cshift.getLoc();
-    fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
     hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
     assert(expr &&
            "expected an expression type for the result of hlfir.cshift");
+    unsigned arrayRank = expr.getRank();
+    // When it is a 1D CSHIFT, we may assume that the DIM argument
+    // (whether it is present or absent) is equal to 1, otherwise,
+    // the program is illegal.
+    int64_t dimVal = 1;
+    if (arrayRank != 1)
+      if (mlir::Value dim = cshift.getDim()) {
+        auto constDim = fir::getIntIfConstant(dim);
+        if (!constDim)
+          return rewriter.notifyMatchFailure(cshift,
+                                             "Nonconstant DIM for CSHIFT");
+        dimVal = *constDim;
+      }
+
+    if (dimVal <= 0 || dimVal > arrayRank)
+      return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
+
+    mlir::Location loc = cshift.getLoc();
+    fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
     mlir::Type elementType = expr.getElementType();
     hlfir::Entity array = hlfir::Entity{cshift.getArray()};
     mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
     llvm::SmallVector<mlir::Value> arrayExtents =
         hlfir::getExplicitExtentsFromShape(arrayShape, builder);
-    unsigned arrayRank = expr.getRank();
     llvm::SmallVector<mlir::Value, 1> typeParams;
     hlfir::genLengthParameters(loc, builder, array, typeParams);
     hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
@@ -394,20 +426,6 @@ class CShiftAsElementalConversion
       shiftVal = builder.createConvert(loc, calcType, shiftVal);
     }
 
-    int64_t dimVal = 1;
-    if (arrayRank == 1) {
-      // When it is a 1D CSHIFT, we may assume that the DIM argument
-      // (whether it is present or absent) is equal to 1, otherwise,
-      // the program is illegal.
-      assert(shiftVal && "SHIFT must be scalar");
-    } else {
-      if (mlir::Value dim = cshift.getDim())
-        dimVal = fir::getIntIfConstant(dim).value_or(0);
-      assert(dimVal > 0 && dimVal <= arrayRank &&
-             "DIM must be present and a positive constant not exceeding "
-             "the array's rank");
-    }
-
     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                          mlir::ValueRange inputIndices) -> hlfir::Entity {
       llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -461,68 +479,19 @@ class SimplifyHLFIRIntrinsics
 public:
   void runOnOperation() override {
     mlir::MLIRContext *context = &getContext();
+
+    mlir::GreedyRewriteConfig config;
+    // Prevent the pattern driver from merging blocks
+    config.enableRegionSimplification =
+        mlir::GreedySimplifyRegionLevel::Disabled;
+
     mlir::RewritePatternSet patterns(context);
     patterns.insert<TransposeAsElementalConversion>(context);
     patterns.insert<SumAsElementalConversion>(context);
     patterns.insert<CShiftAsElementalConversion>(context);
-    mlir::ConversionTarget target(*context);
-    // don't transform transpose of polymorphic arrays (not currently supported
-    // by hlfir.elemental)
-    target.addDynamicallyLegalOp<hlfir::TransposeOp>(
-        [](hlfir::TransposeOp transpose) {
-          return mlir::cast<hlfir::ExprType>(transpose.getType())
-              .isPolymorphic();
-        });
-    // Handle only SUM(DIM=CONSTANT) case for now.
-    // It may be beneficial to expand the non-DIM case as well.
-    // E.g. when the input array is an elemental array expression,
-    // expanding the SUM into a total reduction loop nest
-    // would avoid creating a temporary for the elemental array expression.
-    target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
-      if (!simplifySum)
-        return true;
-
-      // Always inline total reductions.
-      if (hlfir::Entity{sum}.getRank() == 0)
-        return false;
-      mlir::Value dim = sum.getDim();
-      if (!dim)
-        return false;
-
-      if (auto dimVal = fir::getIntIfConstant(dim)) {
-        fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
-            hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
-        if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
-          // Ignore SUMs with illegal DIM values.
-          // They may appear in dead code,
-          // and they do not have to be converted.
-          return false;
-        }
-      }
-      return true;
-    });
-    target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
-      unsigned resultRank = hlfir::Entity{cshift}.getRank();
-      if (resultRank == 1)
-        return false;
-
-      mlir::Value dim = cshift.getDim();
-      if (!dim)
-        return false;
-
-      // If DIM is present, then it must be constant to please
-      // the conversion. In addition, ignore cases with
-      // illegal DIM values.
-      if (auto dimVal = fir::getIntIfConstant(dim))
-        if (*dimVal > 0 && *dimVal <= resultRank)
-          return false;
-
-      return true;
-    });
-    target.markUnknownOpDynamicallyLegal(
-        [](mlir::Operation *) { return true; });
-    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
-                                               std::move(patterns)))) {
+
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(patterns), config))) {
       mlir::emitError(getOperation()->getLoc(),
                       "failure in HLFIR intrinsic simplification");
       signalPassFailure();
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-cshift.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-cshift.fir
index acb89c0719aa08..d21d7755062ba7 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-cshift.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-cshift.fir
@@ -1,17 +1,19 @@
 // Test hlfir.cshift simplification to hlfir.elemental:
 // RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
 
-func.func @cshift_vector(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<i32>) {
+func.func @cshift_vector(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<i32>) -> !hlfir.expr<?xi32>{
   %res = hlfir.cshift %arg0 %arg1 : (!fir.box<!fir.array<?xi32>>, !fir.ref<i32>) -> !hlfir.expr<?xi32>
-  return
+  return %res : !hlfir.expr<?xi32>
 }
 // CHECK-LABEL:   func.func @cshift_vector(
 // CHECK-SAME:                             %[[VAL_0:.*]]: !fir.box<!fir.array<?xi32>>,
-// CHECK-SAME:                             %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK-SAME:                             %[[VAL_1:.*]]: !fir.ref<i32>) -> !hlfir.expr<?xi32> {
+// CHECK:           %[[VAL_26:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_16:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i64
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i64
 // CHECK:           %[[VAL_6:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i32) -> i64
 // CHECK:           %[[VAL_8:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
@@ -22,7 +24,6 @@ func.func @cshift_vector(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<i32
 // CHECK:             %[[VAL_13:.*]] = fir.convert %[[VAL_3]]#1 : (index) -> i64
 // CHECK:             %[[VAL_14:.*]] = arith.remsi %[[VAL_12]], %[[VAL_13]] : i64
 // CHECK:             %[[VAL_15:.*]] = arith.xori %[[VAL_12]], %[[VAL_13]] : i64
-// CHECK:             %[[VAL_16:.*]] = arith.constant 0 : i64
 // CHECK:             %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : i64
 // CHECK:             %[[VAL_18:.*]] = arith.cmpi ne, %[[VAL_14]], %[[VAL_16]] : i64
 // CHECK:             %[[VAL_19:.*]] = arith.andi %[[VAL_18]], %[[VAL_17]] : i1
@@ -30,9 +31,7 @@ func.func @cshift_vector(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<i32
 // CHECK:             %[[VAL_21:.*]] = arith.select %[[VAL_19]], %[[VAL_20]], %[[VAL_14]] : i64
 // CHECK:             %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : i64
 // CHECK:             %[[VAL_23:.*]] = fir.convert %[[VAL_22]] : (i64) -> index
-// CHECK:             %[[VAL_24:.*]] = arith.constant 0 : index
-// CHECK:             %[[VAL_25:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_24]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
-// CHECK:             %[[VAL_26:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_25:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_2]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK:             %[[VAL_27:.*]] = arith.subi %[[VAL_25]]#0, %[[VAL_26]] : index
 // CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_23]], %[[VAL_27]] : index
 // CHECK:             %[[VAL_29:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_28]])  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
@@ -42,21 +41,21 @@ func.func @cshift_vector(%arg0: !fir.box<!fir.array<?xi32>>, %arg1: !fir.ref<i32
 // CHECK:           return
 // CHECK:         }
 
-func.func @cshift_2d_by_scalar(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir.ref<i32>) {
+func.func @cshift_2d_by_scalar(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir.ref<i32>) -> !hlfir.expr<?x?xi32> {
   %dim = arith.constant 2 : i32
   %res = hlfir.cshift %arg0 %arg1 dim %dim : (!fir.box<!fir.array<?x?xi32>>, !fir.ref<i32>, i32) -> !hlfir.expr<?x?xi32>
-  return
+  return %res : !hlfir.expr<?x?xi32>
 }
 // CHECK-LABEL:   func.func @cshift_2d_by_scalar(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xi32>>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: !fir.ref<i32>) {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !fir.ref<i32>) -> !hlfir.expr<?x?xi32> {
+// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[VAL_7:.*]] = fir.shape %[[VAL_4]]#1, %[[VAL_6]]#1 : (index, index) -> !fir.shape<2>
-// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i64
 // CHECK:           %[[VAL_9:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
 // CHECK:           %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (i32) -> i64
 // CHECK:           %[[VAL_11:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
@@ -67,7 +66,6 @@ func.func @cshift_2d_by_scalar(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir
 // CHECK:             %[[VAL_17:.*]] = fir.convert %[[VAL_6]]#1 : (index) -> i64
 // CHECK:             %[[VAL_18:.*]] = arith.remsi %[[VAL_16]], %[[VAL_17]] : i64
 // CHECK:             %[[VAL_19:.*]] = arith.xori %[[VAL_16]], %[[VAL_17]] : i64
-// CHECK:             %[[VAL_20:.*]] = arith.constant 0 : i64
 // CHECK:             %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : i64
 // CHECK:             %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_18]], %[[VAL_20]] : i64
 // CHECK:             %[[VAL_23:.*]] = arith.andi %[[VAL_22]], %[[VAL_21]] : i1
@@ -75,14 +73,11 @@ func.func @cshift_2d_by_scalar(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir
 // CHECK:             %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_18]] : i64
 // CHECK:             %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_8]] : i64
 // CHECK:             %[[VAL_27:.*]] = fir.convert %[[VAL_26]] : (i64) -> index
-// CHECK:             %[[VAL_28:.*]] = arith.constant 0 : index
-// CHECK:             %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_28]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
-// CHECK:             %[[VAL_30:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_31:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_30]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
-// CHECK:             %[[VAL_32:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_33:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_32]] : index
+// CHECK:             %[[VAL_29:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK:             %[[VAL_31:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
+// CHECK:             %[[VAL_33:.*]] = arith.subi %[[VAL_29]]#0, %[[VAL_5]] : index
 // CHECK:             %[[VAL_34:.*]] = arith.addi %[[VAL_12]], %[[VAL_33]] : index
-// CHECK:             %[[VAL_35:.*]] = arith.subi %[[VAL_31]]#0, %[[VAL_32]] : index
+// CHECK:             %[[VAL_35:.*]] = arith.subi %[[VAL_31]]#0, %[[VAL_5]] : index
 // CHECK:             %[[VAL_36:.*]] = arith.addi %[[VAL_27]], %[[VAL_35]] : index
 // CHECK:             %[[VAL_37:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_34]], %[[VAL_36]])  : (!fir.box<!fir.array<?x?xi32>>, index, index) -> !fir.ref<i32>
 // CHECK:             %[[VAL_38:.*]] = fir.load %[[VAL_37]] : !fir.ref<i32>
@@ -91,27 +86,25 @@ func.func @cshift_2d_by_scalar(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir
 // CHECK:           return
 // CHECK:         }
 
-func.func @cshift_2d_by_vector(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir.box<!fir.array<?xi32>>) {
+func.func @cshift_2d_by_vector(%arg0: !fir.box<!fir.array<?x?xi32>>, %arg1: !fir.box<!fir.array<?xi32>>) -> !hlfir.expr<?x?xi32> {
   %dim = arith.constant 2 : i32
   %res = hlfir.cshift %arg0 %arg1 dim %dim : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?xi32>>, i32) -> !hlfir.expr<?x?xi32>
-  return
+  return %res : !hlfir.expr<?x?xi32>
 }
 // CHECK-LABEL:   func.func @cshift_2d_by_vector(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: !fir.box<!fir.array<?x?xi32>>,
-// CHECK-SAME:                                   %[[VAL_1:.*]]: !fir.box<!fir.array<?xi32>>) {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 2 : i32
+// CHECK-SAME:                                   %[[VAL_1:.*]]: !fir.box<!fir.array<?xi32>>) -> !hlfir.expr<?x?xi32> {
+// CHECK:           %[[VAL_26:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_3]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_5]] : (!fir.box<!fir.array<?x?xi32>>, index) -> (index, index, index)
 // CHECK:           %[[VAL_7:.*]] = fir.shape %[[VAL_4]]#1, %[[VAL_6]]#1 : (index, index) -> !fir.shape<2>
-// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i64
 // CHECK:           %[[VAL_9:.*]] = hlfir.elemental %[[VAL_7]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
 // CHECK:           ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
-// CHECK:             %[[VAL_12:.*]] = arith.constant 0 : index
-// CHECK:             %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_12]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
-// CHECK:             %[[VAL_14:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_15:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_14]] : index
+// CHECK:             %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_1]], %[[VAL_3]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK:             %[[VAL_15:.*]] = arith.subi %[[VAL_13]]#0, %[[VAL_5]] : index
 // CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_10]], %[[VAL_15]] : index
 // CHECK:             %[[VAL_17:.*]] = hlfir.designate %[[VAL_1]] (%[[VAL_16]])  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK:             %[[VAL_18:.*]] = fir.load %[[VAL_17]] : !fir.ref<i32>
@@ -122,7 +115,6 @@ func.func @cshift_2d_by_ve...
[truncated]

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@vzakhari vzakhari merged commit 2402bcc into llvm:main Dec 16, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants