Skip to content

Commit 303249c

Browse files
authored
[flang][StackArrays] track pointers through fir.convert (#121919)
This does add a little computational complexity because now every freemem operation has to be tested for every allocation. This could be improved with some more memoisation but I think it is easier to read this way. Let me know if you would prefer me to change this to pre-compute the normalised addresses each freemem operation is using. Weirdly, this change resulted in a verifier failure for the fir.declare in the previous test case. Maybe it was previously removed as dead code and now it isn't. Anyway I fixed that too.
1 parent 44e8ee7 commit 303249c

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

flang/lib/Optimizer/Transforms/StackArrays.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,18 @@ std::optional<AllocationState> LatticePoint::get(mlir::Value val) const {
330330
return it->second;
331331
}
332332

333+
static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) {
334+
while (mlir::Operation *op = value.getDefiningOp()) {
335+
if (auto declareOp = llvm::dyn_cast<fir::DeclareOp>(op))
336+
value = declareOp.getMemref();
337+
else if (auto convertOp = llvm::dyn_cast<fir::ConvertOp>(op))
338+
value = convertOp->getOperand(0);
339+
else
340+
return value;
341+
}
342+
return value;
343+
}
344+
333345
mlir::LogicalResult AllocationAnalysis::visitOperation(
334346
mlir::Operation *op, const LatticePoint &before, LatticePoint *after) {
335347
LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op
@@ -363,10 +375,10 @@ mlir::LogicalResult AllocationAnalysis::visitOperation(
363375
mlir::Value operand = op->getOperand(0);
364376

365377
// Note: StackArrays is scheduled in the pass pipeline after lowering hlfir
366-
// to fir. Therefore, we only need to handle `fir::DeclareOp`s.
367-
if (auto declareOp =
368-
llvm::dyn_cast_if_present<fir::DeclareOp>(operand.getDefiningOp()))
369-
operand = declareOp.getMemref();
378+
// to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look
379+
// past converts in case the pointer was changed between different pointer
380+
// types.
381+
operand = lookThroughDeclaresAndConverts(operand);
370382

371383
std::optional<AllocationState> operandState = before.get(operand);
372384
if (operandState && *operandState == AllocationState::Allocated) {
@@ -535,17 +547,12 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
535547

536548
// remove freemem operations
537549
llvm::SmallVector<mlir::Operation *> erases;
538-
for (mlir::Operation *user : allocmem.getOperation()->getUsers()) {
539-
if (auto declareOp = mlir::dyn_cast_if_present<fir::DeclareOp>(user)) {
540-
for (mlir::Operation *user : declareOp->getUsers()) {
541-
if (mlir::isa<fir::FreeMemOp>(user))
542-
erases.push_back(user);
543-
}
544-
}
545-
546-
if (mlir::isa<fir::FreeMemOp>(user))
547-
erases.push_back(user);
548-
}
550+
mlir::Operation *parent = allocmem->getParentOp();
551+
// TODO: this shouldn't need to be re-calculated for every allocmem
552+
parent->walk([&](fir::FreeMemOp freeOp) {
553+
if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem)
554+
erases.push_back(freeOp);
555+
});
549556

550557
// now we are done iterating the users, it is safe to mutate them
551558
for (mlir::Operation *erase : erases)

flang/test/Transforms/stack-arrays.fir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ func.func @placement_loop_declare() {
379379
%3 = arith.addi %c1, %c2 : index
380380
// operand is now available
381381
%4 = fir.allocmem !fir.array<?xi32>, %3
382-
%5 = fir.declare %4 {uniq_name = "temp"} : (!fir.heap<!fir.array<?xi32>>) -> !fir.heap<!fir.array<?xi32>>
382+
%shape = fir.shape %3 : (index) -> !fir.shape<1>
383+
%5 = fir.declare %4(%shape) {uniq_name = "temp"} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.heap<!fir.array<?xi32>>
383384
// ...
384385
fir.freemem %5 : !fir.heap<!fir.array<?xi32>>
385386
fir.result %3, %c1_i32 : index, i32
@@ -400,3 +401,20 @@ func.func @placement_loop_declare() {
400401
// CHECK-NEXT: }
401402
// CHECK-NEXT: return
402403
// CHECK-NEXT: }
404+
405+
// Can we look through fir.convert and fir.declare?
406+
func.func @lookthrough() {
407+
%0 = fir.allocmem !fir.array<42xi32>
408+
%c42 = arith.constant 42 : index
409+
%shape = fir.shape %c42 : (index) -> !fir.shape<1>
410+
%1 = fir.declare %0(%shape) {uniq_name = "name"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> !fir.heap<!fir.array<42xi32>>
411+
%2 = fir.convert %1 : (!fir.heap<!fir.array<42xi32>>) -> !fir.ref<!fir.array<42xi32>>
412+
// use the ref so the converts aren't folded
413+
%3 = fir.load %2 : !fir.ref<!fir.array<42xi32>>
414+
%4 = fir.convert %2 : (!fir.ref<!fir.array<42xi32>>) -> !fir.heap<!fir.array<42xi32>>
415+
fir.freemem %4 : !fir.heap<!fir.array<42xi32>>
416+
return
417+
}
418+
// CHECK: func.func @lookthrough() {
419+
// CHECK: fir.alloca !fir.array<42xi32>
420+
// CHECK-NOT: fir.freemem

0 commit comments

Comments
 (0)