Skip to content

[flang] allow assumed-rank box in fir.store #95980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2024

Conversation

jeanPerier
Copy link
Contributor

Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case.
Rational described in https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.

Codegen is done with a memcpy using the rank from the "value"
descriptor like for the fir.load case.
@jeanPerier jeanPerier requested review from clementval and tblah June 18, 2024 20:00
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Jun 18, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2024

@llvm/pr-subscribers-flang-codegen

Author: None (jeanPerier)

Changes

Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case.
Rational described in https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.


Full diff: https://github.com/llvm/llvm-project/pull/95980.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+20-11)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (-2)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+19)
  • (modified) flang/test/Fir/tbaa.fir (+19)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 4448224024f20..803d9e6086553 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3143,23 +3143,32 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = store.getLoc();
     mlir::Type storeTy = store.getValue().getType();
-    mlir::LLVM::StoreOp newStoreOp;
+    mlir::Value llvmValue = adaptor.getValue();
+    mlir::Value llvmMemref = adaptor.getMemref();
+    mlir::LLVM::AliasAnalysisOpInterface newOp;
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
-      // fir.box value is actually in memory, load it first before storing it.
       mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
-      auto val = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy,
-                                                     adaptor.getOperands()[0]);
-      attachTBAATag(val, boxTy, boxTy, nullptr);
-      newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
-          loc, val, adaptor.getOperands()[1]);
+      // fir.box value is actually in memory, load it first before storing it,
+      // or do a memcopy for assumed-rank descriptors.
+      if (boxTy.isAssumedRank()) {
+        TypePair boxTypePair{boxTy, llvmBoxTy};
+        mlir::Value boxSize =
+            computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
+        newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
+            loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+      } else {
+        auto val =
+            rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy, llvmValue);
+        attachTBAATag(val, boxTy, boxTy, nullptr);
+        newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, val, llvmMemref);
+      }
     } else {
-      newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
-          loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+      newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
     }
     if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
-      newStoreOp.setTBAATags(*optionalTag);
+      newOp.setTBAATags(*optionalTag);
     else
-      attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
+      attachTBAATag(newOp, storeTy, storeTy, nullptr);
     rewriter.eraseOp(store);
     return mlir::success();
   }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ea8a9752eeeee..9b412deaae99b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3793,8 +3793,6 @@ void fir::StoreOp::print(mlir::OpAsmPrinter &p) {
 mlir::LogicalResult fir::StoreOp::verify() {
   if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType()))
     return emitOpError("store value type must match memory reference type");
-  if (fir::isa_unknown_size_box(getValue().getType()))
-    return emitOpError("cannot store !fir.box of unknown rank or type");
   return mlir::success();
 }
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index d7059671d3a88..f4a2475458f3a 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -893,6 +893,25 @@ func.func @store_unlimited_polymorphic_box(%arg0 : !fir.class<none>, %arg1 : !fi
 // CHECK:  llvm.store %[[VAL_11]], %{{.*}} : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i{{.*}}>>, ptr, array<1 x i{{.*}}>)>, !llvm.ptr
 
 
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// CHECK-LABEL:   llvm.func @store_assumed_rank_box(
+// CHECK-SAME:                                      %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:                                      %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i8
+// CHECK:           %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i8 to i32
+// CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.add %[[VAL_2]], %[[VAL_7]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %[[VAL_8]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+
 // -----
 
 // Test `fir.load` --> `llvm.load` conversion
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 5800e608da41d..89679afc386c6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -407,3 +407,22 @@ func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
 // CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
 // CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// CHECK-DAG:     #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG:     #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG:     #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG:     #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL:   llvm.func @store_assumed_rank_box(
+// CHECK-SAME:                                      %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:                                      %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.load %[[VAL_3]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %{{.*}}) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()

@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (jeanPerier)

Changes

Codegen is done with a memcpy using the rank from the "value" descriptor like for the fir.load case.
Rational described in https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.


Full diff: https://github.com/llvm/llvm-project/pull/95980.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+20-11)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (-2)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+19)
  • (modified) flang/test/Fir/tbaa.fir (+19)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 4448224024f20..803d9e6086553 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3143,23 +3143,32 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
                   mlir::ConversionPatternRewriter &rewriter) const override {
     mlir::Location loc = store.getLoc();
     mlir::Type storeTy = store.getValue().getType();
-    mlir::LLVM::StoreOp newStoreOp;
+    mlir::Value llvmValue = adaptor.getValue();
+    mlir::Value llvmMemref = adaptor.getMemref();
+    mlir::LLVM::AliasAnalysisOpInterface newOp;
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
-      // fir.box value is actually in memory, load it first before storing it.
       mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
-      auto val = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy,
-                                                     adaptor.getOperands()[0]);
-      attachTBAATag(val, boxTy, boxTy, nullptr);
-      newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
-          loc, val, adaptor.getOperands()[1]);
+      // fir.box value is actually in memory, load it first before storing it,
+      // or do a memcopy for assumed-rank descriptors.
+      if (boxTy.isAssumedRank()) {
+        TypePair boxTypePair{boxTy, llvmBoxTy};
+        mlir::Value boxSize =
+            computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
+        newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
+            loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+      } else {
+        auto val =
+            rewriter.create<mlir::LLVM::LoadOp>(loc, llvmBoxTy, llvmValue);
+        attachTBAATag(val, boxTy, boxTy, nullptr);
+        newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, val, llvmMemref);
+      }
     } else {
-      newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
-          loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+      newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
     }
     if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
-      newStoreOp.setTBAATags(*optionalTag);
+      newOp.setTBAATags(*optionalTag);
     else
-      attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
+      attachTBAATag(newOp, storeTy, storeTy, nullptr);
     rewriter.eraseOp(store);
     return mlir::success();
   }
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ea8a9752eeeee..9b412deaae99b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3793,8 +3793,6 @@ void fir::StoreOp::print(mlir::OpAsmPrinter &p) {
 mlir::LogicalResult fir::StoreOp::verify() {
   if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType()))
     return emitOpError("store value type must match memory reference type");
-  if (fir::isa_unknown_size_box(getValue().getType()))
-    return emitOpError("cannot store !fir.box of unknown rank or type");
   return mlir::success();
 }
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index d7059671d3a88..f4a2475458f3a 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -893,6 +893,25 @@ func.func @store_unlimited_polymorphic_box(%arg0 : !fir.class<none>, %arg1 : !fi
 // CHECK:  llvm.store %[[VAL_11]], %{{.*}} : !llvm.struct<(ptr, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i{{.*}}>>, ptr, array<1 x i{{.*}}>)>, !llvm.ptr
 
 
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// CHECK-LABEL:   llvm.func @store_assumed_rank_box(
+// CHECK-SAME:                                      %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:                                      %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_2:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr -> i8
+// CHECK:           %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i8 to i32
+// CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mul %[[VAL_6]], %[[VAL_5]] : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.add %[[VAL_2]], %[[VAL_7]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %[[VAL_8]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+
 // -----
 
 // Test `fir.load` --> `llvm.load` conversion
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 5800e608da41d..89679afc386c6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -407,3 +407,22 @@ func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
 // CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
 // CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
+func.func @store_assumed_rank_box(%box: !fir.box<!fir.array<*:f32>>, %ref: !fir.ref<!fir.box<!fir.array<*:f32>>>) {
+  fir.store %box to %ref : !fir.ref<!fir.box<!fir.array<*:f32>>>
+  return
+}
+
+// CHECK-DAG:     #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG:     #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG:     #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG:     #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL:   llvm.func @store_assumed_rank_box(
+// CHECK-SAME:                                      %[[VAL_0:[^:]*]]: !llvm.ptr,
+// CHECK-SAME:                                      %[[VAL_1:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_4:.*]] = llvm.load %[[VAL_3]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_1]], %[[VAL_0]], %{{.*}}) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jeanPerier jeanPerier merged commit a786919 into llvm:main Jun 19, 2024
11 checks passed
@jeanPerier jeanPerier deleted the jpr-assumed-rank-store branch June 19, 2024 08:12
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
Codegen is done with a memcpy using the rank from the "value" descriptor
like for the fir.load case.
Rational described in
https://github.com/llvm/llvm-project/blob/main/flang/docs/AssumedRank.md.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants