-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] Use LLVM dialect ops for stack save/restore in target-rewrite #107879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-codegen @llvm/pr-subscribers-flang-fir-hlfir Author: None (jeanPerier) ChangesMostly NFC, I was bothered by the declaration that were always made even if unsued, and I think using LLVM Ops is nicer anyway with regards to side effects here.
There are other places in lowering that are using the calls instead of the LLVM intrinsics, but I will deal with them another time (the issue there is mostly to get the proper address space for the llvm.ptr type). Full diff: https://github.com/llvm/llvm-project/pull/107879.diff 5 Files Affected:
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index e9e303df09eeba..2e097faec54036 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -68,7 +68,7 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
representations that may differ based on the target machine.
}];
let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
- "mlir::DLTIDialect" ];
+ "mlir::DLTIDialect", "mlir::LLVM::LLVMDialect" ];
let options = [
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
"Override module's target triple.">,
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 85bf90e4750633..a2a9cff4c4977e 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -27,6 +27,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -114,13 +115,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
setMembers(specifics.get(), &rewriter, &*dl);
- // We may need to call stacksave/stackrestore later, so
- // create the FuncOps beforehand.
- fir::FirOpBuilder builder(rewriter, mod);
- builder.setInsertionPointToStart(mod.getBody());
- stackSaveFn = fir::factory::getLlvmStackSave(builder);
- stackRestoreFn = fir::factory::getLlvmStackRestore(builder);
-
// Perform type conversion on signatures and call sites.
if (mlir::failed(convertTypes(mod))) {
mlir::emitError(mlir::UnknownLoc::get(&context),
@@ -1242,22 +1236,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
+ uint64_t getAllocaAddressSpace() const {
+ if (dataLayout)
+ if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
+ return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
+ return 0;
+ }
+
// Inserts a call to llvm.stacksave at the current insertion
// point and the given location. Returns the call's result Value.
inline mlir::Value genStackSave(mlir::Location loc) {
- return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
+ mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
+ rewriter->getContext(), getAllocaAddressSpace());
+ return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
}
// Inserts a call to llvm.stackrestore at the current insertion
// point and the given location and argument.
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
- rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
+ rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
}
fir::CodeGenSpecifics *specifics = nullptr;
mlir::OpBuilder *rewriter = nullptr;
mlir::DataLayout *dataLayout = nullptr;
- mlir::func::FuncOp stackSaveFn = nullptr;
- mlir::func::FuncOp stackRestoreFn = nullptr;
};
} // namespace
diff --git a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
index 9d4745becd8523..e37e8dd4481d06 100644
--- a/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir
@@ -13,13 +13,13 @@ func.func @test_call_i16(%0 : !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK-LABEL: func.func @test_call_i16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti16{i:i16}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti16{i:i16}>>
-// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca i16
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<i16>) -> !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti16{i:i16}>>
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<i16>
// CHECK: fir.call @test_func_i16(%[[VAL_5]]) : (i16) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
func.func private @test_func_i16(%0 : !fir.type<ti16{i:i16}>) -> () {
return
diff --git a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
index 82139492cea700..9a0a41e1da542a 100644
--- a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
+++ b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
@@ -14,7 +14,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
// CHECK-LABEL: func.func @test_call_i8_a16(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>) {
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
-// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_3:.*]] = fir.alloca tuple<i64, i64>
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, i64>>) -> !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
@@ -22,7 +22,7 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
// CHECK: %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple<i64, i64>) -> i64
// CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple<i64, i64>) -> i64
// CHECK: fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
// CHECK: return
func.func private @test_func_i8_a16(%0 : !fir.type<ti8_a16{a:!fir.array<16xi8>}>) -> () {
diff --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir
index 69ee28ea337bf6..304f15a828454e 100644
--- a/flang/test/Fir/target-rewrite-complex16.fir
+++ b/flang/test/Fir/target-rewrite-complex16.fir
@@ -63,18 +63,18 @@ func.func @addrof() {
// CHECK: func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
// CHECK-LABEL: func.func @callcomplex16() {
-// CHECK: %[[VAL_0:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_0:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_1:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
// CHECK: fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.complex<16>>
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_0]]) : (!fir.ref<i8>) -> ()
-// CHECK: %[[VAL_4:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: llvm.intr.stackrestore %[[VAL_0]] : !llvm.ptr
+// CHECK: %[[VAL_4:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_4]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_4]] : !llvm.ptr
// CHECK: return
// CHECK: }
// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
@@ -87,7 +87,7 @@ func.func @addrof() {
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.complex<16>>
-// CHECK: %[[VAL_9:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_9:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_10:.*]] = fir.alloca !fir.complex<16>
// CHECK: fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
@@ -98,7 +98,7 @@ func.func @addrof() {
// CHECK: fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref<!fir.complex<16>>
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_9]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_9]] : !llvm.ptr
// CHECK: return
// CHECK: }
@@ -108,7 +108,7 @@ func.func @addrof() {
// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<complex<f128>>
-// CHECK: %[[VAL_7:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
+// CHECK: %[[VAL_7:.*]] = llvm.intr.stacksave : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = fir.alloca tuple<f128, f128>
// CHECK: %[[VAL_9:.*]] = fir.alloca complex<f128>
// CHECK: fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref<complex<f128>>
@@ -119,7 +119,7 @@ func.func @addrof() {
// CHECK: fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>) -> ()
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<complex<f128>>
-// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_7]]) : (!fir.ref<i8>) -> ()
+// CHECK: llvm.intr.stackrestore %[[VAL_7]] : !llvm.ptr
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
// CHECK: fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref<complex<f128>>
// CHECK: return
@@ -130,5 +130,3 @@ func.func @addrof() {
// CHECK: %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
// CHECK: return
// CHECK: }
-// CHECK: func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
-// CHECK: func.func private @llvm.stackrestore.p0(!fir.ref<i8>)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, I hadn't noticed these operations existed.
There are other places in flang that we use stacksave and stackrestore. If you don't plan to update these please could you make a ticket and assign to me so I don't forget.
Thank you for addressing this! When doing this it would be greatly appreciated if you utilise the alloca address spaces similar to this PR! As, otherwise it runs the risk of breaking ISEL for AMDGPU if pulled in, the lowering flow for AMDGPU gets a little angry and breaks when address spaces aren't adhered to. |
The new LLVM stack save/restore intrinsic operations are more convenient than function calls because they do not add function declarations to the module and therefore do not block the parallelisation of passes. Furthermore they could be much more easily marked with memory effects than function calls if that ever proved useful. This builds on top of llvm#107879. Resolves llvm#108016
…108562) The new LLVM stack save/restore intrinsic operations are more convenient than function calls because they do not add function declarations to the module and therefore do not block the parallelisation of passes. Furthermore they could be much more easily marked with memory effects than function calls if that ever proved useful. This builds on top of #107879. Resolves #108016
Mostly NFC, I was bothered by the declaration that were always made even if unsued, and I think using LLVM Ops is nicer anyway with regards to side effects here.
There are other places in lowering that are using the calls instead of the LLVM intrinsics, but I will deal with them another time (the issue there is mostly to get the proper address space for the llvm.ptr type).