Skip to content

[flang] Implement conversion of compatible derived types #111165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 9, 2024

Conversation

luporl
Copy link
Contributor

@luporl luporl commented Oct 4, 2024

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes #107783

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes llvm#107783
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Oct 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Leandro Lupori (luporl)

Changes

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes #107783


Full diff: https://github.com/llvm/llvm-project/pull/111165.diff

5 Files Affected:

  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+4-1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+25)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+11-1)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+25)
  • (modified) flang/test/Fir/invalid.fir (+8)
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 9ad37c8df434a2..8fa695a5c0c2e1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder,
                                         mlir::Location loc, mlir::Type toTy,
                                         mlir::Value val) {
   if (val.getType() != toTy) {
-    assert(!fir::isa_derived(toTy));
+    assert((!fir::isa_derived(toTy) ||
+            mlir::cast<fir::RecordType>(val.getType()).getTypeList() ==
+                mlir::cast<fir::RecordType>(toTy).getTypeList()) &&
+           "incompatible record types");
     return builder.create<fir::ConvertOp>(loc, toTy, val);
   }
   return val;
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1cb869bfeb95a8..19c38a1ba6be26 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
     auto loc = convert.getLoc();
     auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
 
+    if (mlir::isa<fir::RecordType>(toFirTy)) {
+      // Convert to compatible BIND(C) record type.
+      // Double check that the record types are compatible (it should have
+      // already been checked by the verifier).
+      assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() ==
+                 mlir::cast<fir::RecordType>(toFirTy).getTypeList() &&
+             "incompatible record types");
+
+      auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy);
+      mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy);
+      auto indexTypeMap = toStTy.getSubelementIndexMap();
+      assert(indexTypeMap.has_value() && "invalid record type");
+
+      for (auto [attr, type] : indexTypeMap.value()) {
+        int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt();
+        auto extVal =
+            rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index);
+        val =
+            rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index);
+      }
+
+      rewriter.replaceOp(convert, val);
+      return mlir::success();
+    }
+
     if (mlir::isa<fir::LogicalType>(fromFirTy) ||
         mlir::isa<fir::LogicalType>(toFirTy)) {
       // By specification fir::LogicalType value may be any number,
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8fdc06f6fce3f5..90ce8b87605912 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) {
   return true;
 }
 
+static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) {
+  // Both records must have the same field types.
+  // Trust frontend semantics for in-depth checks, such as if both records
+  // have the BIND(C) attribute.
+  auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy);
+  auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy);
+  return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList();
+}
+
 bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
   if (inType == outType)
     return true;
@@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
          (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) ||
-         areVectorsCompatible(inType, outType);
+         areVectorsCompatible(inType, outType) ||
+         areRecordsCompatible(inType, outType);
 }
 
 llvm::LogicalResult fir::ConvertOp::verify() {
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 0c17d7c25a8c8d..1182a0a10f218b 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex<f128>) -> complex<f16> {
 
 // -----
 
+// Test `fir.convert` operation conversion between compatible fir.record types.
+
+func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                                  !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> {
+    %0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                              !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+  return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+}
+
+// CHECK-LABEL: func @convert_record(
+// CHECK-SAME:    %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) ->
+// CHECK-SAME:                  [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]
+// CHECK:         %{{.*}} = llvm.mlir.undef : [[MOD2_REC]]
+// CHECK-DAG:     %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]]
+// CHECK-DAG:     %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]]
+// CHECK-DAG:     %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]]
+// CHECK-DAG:     %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]]
+// CHECK:         llvm.return %{{.*}} : [[MOD2_REC]]
+
+// -----
+
 // Test `fir.store` --> `llvm.store` conversion
 
 func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) {
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 086a426db5642e..7e3f9d64984129 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
 
 // -----
 
+func.func @rec_to_rec(%arg0: !fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}> {
+  // expected-error@+1{{'fir.convert' op invalid type conversion}}
+  %0 = fir.convert %arg0 : (!fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}>
+  return %0 : !fir.type<t2{f:f32, i:i32}>
+}
+
+// -----
+
 func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
   // expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
   %addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>

@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-flang-codegen

Author: Leandro Lupori (luporl)

Changes

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes #107783


Full diff: https://github.com/llvm/llvm-project/pull/111165.diff

5 Files Affected:

  • (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+4-1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+25)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+11-1)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+25)
  • (modified) flang/test/Fir/invalid.fir (+8)
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 9ad37c8df434a2..8fa695a5c0c2e1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder,
                                         mlir::Location loc, mlir::Type toTy,
                                         mlir::Value val) {
   if (val.getType() != toTy) {
-    assert(!fir::isa_derived(toTy));
+    assert((!fir::isa_derived(toTy) ||
+            mlir::cast<fir::RecordType>(val.getType()).getTypeList() ==
+                mlir::cast<fir::RecordType>(toTy).getTypeList()) &&
+           "incompatible record types");
     return builder.create<fir::ConvertOp>(loc, toTy, val);
   }
   return val;
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1cb869bfeb95a8..19c38a1ba6be26 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
     auto loc = convert.getLoc();
     auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
 
+    if (mlir::isa<fir::RecordType>(toFirTy)) {
+      // Convert to compatible BIND(C) record type.
+      // Double check that the record types are compatible (it should have
+      // already been checked by the verifier).
+      assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() ==
+                 mlir::cast<fir::RecordType>(toFirTy).getTypeList() &&
+             "incompatible record types");
+
+      auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy);
+      mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy);
+      auto indexTypeMap = toStTy.getSubelementIndexMap();
+      assert(indexTypeMap.has_value() && "invalid record type");
+
+      for (auto [attr, type] : indexTypeMap.value()) {
+        int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt();
+        auto extVal =
+            rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index);
+        val =
+            rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index);
+      }
+
+      rewriter.replaceOp(convert, val);
+      return mlir::success();
+    }
+
     if (mlir::isa<fir::LogicalType>(fromFirTy) ||
         mlir::isa<fir::LogicalType>(toFirTy)) {
       // By specification fir::LogicalType value may be any number,
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8fdc06f6fce3f5..90ce8b87605912 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) {
   return true;
 }
 
+static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) {
+  // Both records must have the same field types.
+  // Trust frontend semantics for in-depth checks, such as if both records
+  // have the BIND(C) attribute.
+  auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy);
+  auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy);
+  return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList();
+}
+
 bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
   if (inType == outType)
     return true;
@@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
          (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) ||
-         areVectorsCompatible(inType, outType);
+         areVectorsCompatible(inType, outType) ||
+         areRecordsCompatible(inType, outType);
 }
 
 llvm::LogicalResult fir::ConvertOp::verify() {
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 0c17d7c25a8c8d..1182a0a10f218b 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex<f128>) -> complex<f16> {
 
 // -----
 
+// Test `fir.convert` operation conversion between compatible fir.record types.
+
+func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                                  !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> {
+    %0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                              !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+  return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+}
+
+// CHECK-LABEL: func @convert_record(
+// CHECK-SAME:    %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) ->
+// CHECK-SAME:                  [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]
+// CHECK:         %{{.*}} = llvm.mlir.undef : [[MOD2_REC]]
+// CHECK-DAG:     %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]]
+// CHECK-DAG:     %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]]
+// CHECK-DAG:     %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]]
+// CHECK-DAG:     %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]]
+// CHECK:         llvm.return %{{.*}} : [[MOD2_REC]]
+
+// -----
+
 // Test `fir.store` --> `llvm.store` conversion
 
 func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) {
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 086a426db5642e..7e3f9d64984129 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
 
 // -----
 
+func.func @rec_to_rec(%arg0: !fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}> {
+  // expected-error@+1{{'fir.convert' op invalid type conversion}}
+  %0 = fir.convert %arg0 : (!fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}>
+  return %0 : !fir.type<t2{f:f32, i:i32}>
+}
+
+// -----
+
 func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
   // expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
   %addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes look good to me, thanks

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@luporl luporl merged commit 390943f into llvm:main Oct 9, 2024
13 checks passed
@luporl luporl deleted the luporl-cvt-derived branch October 9, 2024 13:37
@luporl
Copy link
Contributor Author

luporl commented Oct 9, 2024

Thanks for the reviews.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Flang] Compilation error when type is defined in module by using derived type with the same name from another module
5 participants