@@ -87,6 +87,13 @@ class ElementalAssignBufferization
87
87
// / determines if the transformation can be applied to this elemental
88
88
static std::optional<MatchInfo> findMatch (hlfir::ElementalOp elemental);
89
89
90
+ // / Returns the array indices for the given hlfir.designate.
91
+ // / It recognizes the computations used to transform the one-based indices
92
+ // / into the array's lb-based indices, and returns the one-based indices
93
+ // / in these cases.
94
+ static llvm::SmallVector<mlir::Value>
95
+ getDesignatorIndices (hlfir::DesignateOp designate);
96
+
90
97
public:
91
98
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
92
99
@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
430
437
return false ;
431
438
}
432
439
440
+ llvm::SmallVector<mlir::Value>
441
+ ElementalAssignBufferization::getDesignatorIndices (
442
+ hlfir::DesignateOp designate) {
443
+ mlir::Value memref = designate.getMemref ();
444
+
445
+ // If the object is a box, then the indices may be adjusted
446
+ // according to the box's lower bound(s). Scan through
447
+ // the computations to try to find the one-based indices.
448
+ if (mlir::isa<fir::BaseBoxType>(memref.getType ())) {
449
+ // Look for the following pattern:
450
+ // %13 = fir.load %12 : !fir.ref<!fir.box<...>
451
+ // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
452
+ // %17 = arith.subi %14#0, %c1 : index
453
+ // %18 = arith.addi %arg2, %17 : index
454
+ // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
455
+ //
456
+ // %arg2 is a one-based index.
457
+
458
+ auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
459
+ // Return true, if v and dim are such that:
460
+ // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
461
+ // %17 = arith.subi %14#0, %c1 : index
462
+ // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
463
+ if (auto subOp =
464
+ mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp ())) {
465
+ auto cst = fir::getIntIfConstant (subOp.getRhs ());
466
+ if (!cst || *cst != 1 )
467
+ return false ;
468
+ if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
469
+ subOp.getLhs ().getDefiningOp ())) {
470
+ if (memref != dimsOp.getVal () ||
471
+ dimsOp.getResult (0 ) != subOp.getLhs ())
472
+ return false ;
473
+ auto dimsOpDim = fir::getIntIfConstant (dimsOp.getDim ());
474
+ return dimsOpDim && dimsOpDim == dim;
475
+ }
476
+ }
477
+ return false ;
478
+ };
479
+
480
+ llvm::SmallVector<mlir::Value> newIndices;
481
+ for (auto index : llvm::enumerate (designate.getIndices ())) {
482
+ if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
483
+ index.value ().getDefiningOp ())) {
484
+ for (unsigned opNum = 0 ; opNum < 2 ; ++opNum)
485
+ if (isNormalizedLb (addOp->getOperand (opNum), index.index ())) {
486
+ newIndices.push_back (addOp->getOperand ((opNum + 1 ) % 2 ));
487
+ break ;
488
+ }
489
+
490
+ // If new one-based index was not added, exit early.
491
+ if (newIndices.size () <= index.index ())
492
+ break ;
493
+ }
494
+ }
495
+
496
+ // If any of the indices is not adjusted to the array's lb,
497
+ // then return the original designator indices.
498
+ if (newIndices.size () != designate.getIndices ().size ())
499
+ return designate.getIndices ();
500
+
501
+ return newIndices;
502
+ }
503
+
504
+ return designate.getIndices ();
505
+ }
506
+
433
507
std::optional<ElementalAssignBufferization::MatchInfo>
434
508
ElementalAssignBufferization::findMatch (hlfir::ElementalOp elemental) {
435
509
mlir::Operation::user_range users = elemental->getUsers ();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
557
631
<< " at " << elemental.getLoc () << " \n " );
558
632
return std::nullopt;
559
633
}
560
- auto indices = designate. getIndices ( );
634
+ auto indices = getDesignatorIndices (designate );
561
635
auto elementalIndices = elemental.getIndices ();
562
636
if (indices.size () == elementalIndices.size () &&
563
637
std::equal (indices.begin (), indices.end (), elementalIndices.begin (),
0 commit comments