Skip to content

Commit 390943f

Browse files
authored
[flang] Implement conversion of compatible derived types (#111165)
With some restrictions, BIND(C) derived types can be converted to compatible BIND(C) derived types. Semantics already support this, but ConvertOp was missing the conversion of such types. Fixes #107783
1 parent a9ebdbb commit 390943f

File tree

5 files changed

+73
-2
lines changed

5 files changed

+73
-2
lines changed

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder,
479479
mlir::Location loc, mlir::Type toTy,
480480
mlir::Value val) {
481481
if (val.getType() != toTy) {
482-
assert(!fir::isa_derived(toTy));
482+
assert((!fir::isa_derived(toTy) ||
483+
mlir::cast<fir::RecordType>(val.getType()).getTypeList() ==
484+
mlir::cast<fir::RecordType>(toTy).getTypeList()) &&
485+
"incompatible record types");
483486
return builder.create<fir::ConvertOp>(loc, toTy, val);
484487
}
485488
return val;

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
660660
auto loc = convert.getLoc();
661661
auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
662662

663+
if (mlir::isa<fir::RecordType>(toFirTy)) {
664+
// Convert to compatible BIND(C) record type.
665+
// Double check that the record types are compatible (it should have
666+
// already been checked by the verifier).
667+
assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() ==
668+
mlir::cast<fir::RecordType>(toFirTy).getTypeList() &&
669+
"incompatible record types");
670+
671+
auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy);
672+
mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy);
673+
auto indexTypeMap = toStTy.getSubelementIndexMap();
674+
assert(indexTypeMap.has_value() && "invalid record type");
675+
676+
for (auto [attr, type] : indexTypeMap.value()) {
677+
int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt();
678+
auto extVal =
679+
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index);
680+
val =
681+
rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index);
682+
}
683+
684+
rewriter.replaceOp(convert, val);
685+
return mlir::success();
686+
}
687+
663688
if (mlir::isa<fir::LogicalType>(fromFirTy) ||
664689
mlir::isa<fir::LogicalType>(toFirTy)) {
665690
// By specification fir::LogicalType value may be any number,

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) {
14101410
return true;
14111411
}
14121412

1413+
static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) {
1414+
// Both records must have the same field types.
1415+
// Trust frontend semantics for in-depth checks, such as if both records
1416+
// have the BIND(C) attribute.
1417+
auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy);
1418+
auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy);
1419+
return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList();
1420+
}
1421+
14131422
bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
14141423
if (inType == outType)
14151424
return true;
@@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
14281437
(fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
14291438
(fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) ||
14301439
(fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) ||
1431-
areVectorsCompatible(inType, outType);
1440+
areVectorsCompatible(inType, outType) ||
1441+
areRecordsCompatible(inType, outType);
14321442
}
14331443

14341444
llvm::LogicalResult fir::ConvertOp::verify() {

flang/test/Fir/convert-to-llvm.fir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex<f128>) -> complex<f16> {
816816

817817
// -----
818818

819+
// Test `fir.convert` operation conversion between compatible fir.record types.
820+
821+
func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
822+
!fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> {
823+
%0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
824+
!fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
825+
return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
826+
}
827+
828+
// CHECK-LABEL: func @convert_record(
829+
// CHECK-SAME: %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) ->
830+
// CHECK-SAME: [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]
831+
// CHECK: %{{.*}} = llvm.mlir.undef : [[MOD2_REC]]
832+
// CHECK-DAG: %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]]
833+
// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]]
834+
// CHECK-DAG: %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]]
835+
// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]]
836+
// CHECK-DAG: %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]]
837+
// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]]
838+
// CHECK-DAG: %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]]
839+
// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]]
840+
// CHECK: llvm.return %{{.*}} : [[MOD2_REC]]
841+
842+
// -----
843+
819844
// Test `fir.store` --> `llvm.store` conversion
820845

821846
func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) {

flang/test/Fir/invalid.fir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
965965

966966
// -----
967967

968+
func.func @rec_to_rec(%arg0: !fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}> {
969+
// expected-error@+1{{'fir.convert' op invalid type conversion}}
970+
%0 = fir.convert %arg0 : (!fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}>
971+
return %0 : !fir.type<t2{f:f32, i:i32}>
972+
}
973+
974+
// -----
975+
968976
func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
969977
// expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
970978
%addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>

0 commit comments

Comments
 (0)