Skip to content

[flang] Use LLVM dialect ops for stack save/restore in target-rewrite #107879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/CodeGen/CGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
representations that may differ based on the target machine.
}];
let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
"mlir::DLTIDialect" ];
"mlir::DLTIDialect", "mlir::LLVM::LLVMDialect" ];
let options = [
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
"Override module's target triple.">,
Expand Down
23 changes: 12 additions & 11 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -114,13 +115,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {

setMembers(specifics.get(), &rewriter, &*dl);

// We may need to call stacksave/stackrestore later, so
// create the FuncOps beforehand.
fir::FirOpBuilder builder(rewriter, mod);
builder.setInsertionPointToStart(mod.getBody());
stackSaveFn = fir::factory::getLlvmStackSave(builder);
stackRestoreFn = fir::factory::getLlvmStackRestore(builder);

// Perform type conversion on signatures and call sites.
if (mlir::failed(convertTypes(mod))) {
mlir::emitError(mlir::UnknownLoc::get(&context),
Expand Down Expand Up @@ -1242,22 +1236,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {

inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }

uint64_t getAllocaAddressSpace() const {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

// Inserts a call to llvm.stacksave at the current insertion
// point and the given location. Returns the call's result Value.
inline mlir::Value genStackSave(mlir::Location loc) {
return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
rewriter->getContext(), getAllocaAddressSpace());
return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
}

// Inserts a call to llvm.stackrestore at the current insertion
// point and the given location and argument.
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
}

fir::CodeGenSpecifics *specifics = nullptr;
mlir::OpBuilder *rewriter = nullptr;
mlir::DataLayout *dataLayout = nullptr;
mlir::func::FuncOp stackSaveFn = nullptr;
mlir::func::FuncOp stackRestoreFn = nullptr;
};
} // namespace
4 changes: 2 additions & 2 deletions flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ func.func @test_call_i16(%0 : !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK-LABEL: func.func @test_call_i16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca i16
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<i16>) -> !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<i16>
// CHECK: fir.call @test_func_i16(%[[VAL_5]]) : (i16) -> ()
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr

func.func private @test_func_i16(%0 : !fir.type<ti16{i:i16}>) -> () {
return
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
// CHECK-LABEL: func.func @test_call_i8_a16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca tuple<i64, i64>
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, i64>>) -> !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<tuple<i64, i64>>
// CHECK: %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple<i64, i64>) -> i64
// CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple<i64, i64>) -> i64
// CHECK: fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
// CHECK: return

func.func private @test_func_i8_a16(%0 : !fir.type<ti8_a16{a:!fir.array<16xi8>}>) -> () {
Expand Down
18 changes: 8 additions & 10 deletions flang/test/Fir/target-rewrite-complex16.fir
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,18 @@ func.func @addrof() {
// CHECK: func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})

// CHECK-LABEL: func.func @callcomplex16() {
// CHECK: %[[VAL_0:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: %[[VAL_0:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_1:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
// CHECK: fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.complex<16>>
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_0]]) : (!fir.ref<i8>) -> ()
// CHECK: %[[VAL_4:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: llvm.intr.stackrestore %[[VAL_0]] : !llvm.ptr
// CHECK: %[[VAL_4:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_4]]) : (!fir.ref<i8>) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_4]] : !llvm.ptr
// CHECK: return
// CHECK: }
// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
Expand All @@ -87,7 +87,7 @@ func.func @addrof() {
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_9:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: %[[VAL_9:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_10:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
Expand All @@ -98,7 +98,7 @@ func.func @addrof() {
// CHECK: fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_9]]) : (!fir.ref<i8>) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_9]] : !llvm.ptr
// CHECK: return
// CHECK: }

Expand All @@ -108,7 +108,7 @@ func.func @addrof() {
// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<complex<f128>>
// CHECK: %[[VAL_7:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
// CHECK: %[[VAL_7:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = fir.alloca tuple<f128, f128>
// CHECK: %[[VAL_9:.*]] = fir.alloca complex<f128>
// CHECK: fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref<complex<f128>>
Expand All @@ -119,7 +119,7 @@ func.func @addrof() {
// CHECK: fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>) -> ()
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<complex<f128>>
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_7]]) : (!fir.ref<i8>) -> ()
// CHECK: llvm.intr.stackrestore %[[VAL_7]] : !llvm.ptr
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref<complex<f128>>
// CHECK: return
Expand All @@ -130,5 +130,3 @@ func.func @addrof() {
// CHECK: %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: return
// CHECK: }
// CHECK: func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
// CHECK: func.func private @llvm.stackrestore.p0(!fir.ref<i8>)
Loading