-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] allow assumed-rank box in fir.store #95980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case.
@llvm/pr-subscribers-flang-codegen Author: None (jeanPerier) ChangesCodegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case. Full diff: https://github.com/llvm/llvm-project/pull/95980.diff 4 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 4448224024f20..803d9e6086553 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3143,23 +3143,32 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = store.getLoc();
mlir::Type storeTy = store.getValue().getType();
- mlir::LLVM::StoreOp newStoreOp;
+ mlir::Value llvmValue = adaptor.getValue();
+ mlir::Value llvmMemref = adaptor.getMemref();
+ mlir::LLVM::AliasAnalysisOpInterface newOp;
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
- // fir.box value is actually in memory, load it first before storing it.
mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
- auto val = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy,
- adaptor.getOperands()[0]);
- attachTBAATag(val, boxTy, boxTy, nullptr);
- newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
- loc, val, adaptor.getOperands()[1]);
+ // fir.box value is actually in memory, load it first before storing it,
+ // or do a memcopy for assumed-rank descriptors.
+ if (boxTy.isAssumedRank()) {
+ TypePair boxTypePair{boxTy, llvmBoxTy};
+ mlir::Value boxSize =
+ computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
+ newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
+ loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+ } else {
+ auto val =
+ rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy, llvmValue);
+ attachTBAATag(val, boxTy, boxTy, nullptr);
+ newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, val, llvmMemref);
+ }
} else {
- newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
- loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+ newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
- newStoreOp.setTBAATags(*optionalTag);
+ newOp.setTBAATags(*optionalTag);
else
- attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
+ attachTBAATag(newOp, storeTy, storeTy, nullptr);
rewriter.eraseOp(store);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ea8a9752eeeee..9b412deaae99b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3793,8 +3793,6 @@ void fir::StoreOp::print(mlir::OpAsmPrinter &p) {
mlir::LogicalResult fir::StoreOp::verify() {
if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType()))
return emitOpError("store value type must match memory reference type");
- if (fir::isa_unknown_size_box(getValue().getType()))
- return emitOpError("cannot store !fir.box of unknown rank or type");
return mlir::success();
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index d7059671d3a88..f4a2475458f3a 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -893,6 +893,25 @@ func.func @store_unlimited_polymorphic_box(%arg0 : !fir.class<none>, %arg1 : !fi
// CHECK: llvm.store %[[VAL_11]], %{{.*}} : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i{{.*}}>>, ptr, array<1 x i{{.*}}>)>, !llvm.ptr
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+ fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+ return
+}
+
+// CHECK-LABEL: llvm.func @store_assumed_rank_box(
+// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i8
+// CHECK: %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i8 to i32
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_7:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i32
+// CHECK: %[[VAL_8:.*]] = llvm.add %[[VAL_2]], %[[VAL_7]] : i32
+// CHECK: "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %[[VAL_8]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+
// -----
// Test `fir.load` --> `llvm.load` conversion
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 5800e608da41d..89679afc386c6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -407,3 +407,22 @@ func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+ fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+ return
+}
+
+// CHECK-DAG: #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG: #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG: #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG: #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL: llvm.func @store_assumed_rank_box(
+// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK: "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %{{.*}}) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
|
@llvm/pr-subscribers-flang-fir-hlfir Author: None (jeanPerier) ChangesCodegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case. Full diff: https://github.com/llvm/llvm-project/pull/95980.diff 4 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 4448224024f20..803d9e6086553 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3143,23 +3143,32 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = store.getLoc();
mlir::Type storeTy = store.getValue().getType();
- mlir::LLVM::StoreOp newStoreOp;
+ mlir::Value llvmValue = adaptor.getValue();
+ mlir::Value llvmMemref = adaptor.getMemref();
+ mlir::LLVM::AliasAnalysisOpInterface newOp;
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
- // fir.box value is actually in memory, load it first before storing it.
mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
- auto val = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy,
- adaptor.getOperands()[0]);
- attachTBAATag(val, boxTy, boxTy, nullptr);
- newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
- loc, val, adaptor.getOperands()[1]);
+ // fir.box value is actually in memory, load it first before storing it,
+ // or do a memcopy for assumed-rank descriptors.
+ if (boxTy.isAssumedRank()) {
+ TypePair boxTypePair{boxTy, llvmBoxTy};
+ mlir::Value boxSize =
+ computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
+ newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
+ loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+ } else {
+ auto val =
+ rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy, llvmValue);
+ attachTBAATag(val, boxTy, boxTy, nullptr);
+ newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, val, llvmMemref);
+ }
} else {
- newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
- loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+ newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
- newStoreOp.setTBAATags(*optionalTag);
+ newOp.setTBAATags(*optionalTag);
else
- attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
+ attachTBAATag(newOp, storeTy, storeTy, nullptr);
rewriter.eraseOp(store);
return mlir::success();
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ea8a9752eeeee..9b412deaae99b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3793,8 +3793,6 @@ void fir::StoreOp::print(mlir::OpAsmPrinter &p) {
mlir::LogicalResult fir::StoreOp::verify() {
if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType()))
return emitOpError("store value type must match memory reference type");
- if (fir::isa_unknown_size_box(getValue().getType()))
- return emitOpError("cannot store !fir.box of unknown rank or type");
return mlir::success();
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index d7059671d3a88..f4a2475458f3a 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -893,6 +893,25 @@ func.func @store_unlimited_polymorphic_box(%arg0 : !fir.class<none>, %arg1 : !fi
// CHECK: llvm.store %[[VAL_11]], %{{.*}} : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i{{.*}}>>, ptr, array<1 x i{{.*}}>)>, !llvm.ptr
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+ fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+ return
+}
+
+// CHECK-LABEL: llvm.func @store_assumed_rank_box(
+// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i8
+// CHECK: %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i8 to i32
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_7:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i32
+// CHECK: %[[VAL_8:.*]] = llvm.add %[[VAL_2]], %[[VAL_7]] : i32
+// CHECK: "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %[[VAL_8]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+
// -----
// Test `fir.load` --> `llvm.load` conversion
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 5800e608da41d..89679afc386c6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -407,3 +407,22 @@ func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+ fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+ return
+}
+
+// CHECK-DAG: #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG: #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG: #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG: #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL: llvm.func @store_assumed_rank_box(
+// CHECK-SAME: %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK: "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %{{.*}}) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case. Rational described in https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.
Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case.
Rational described in https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.