Skip to content

[flang] Improve runtime SAME_TYPE_AS() #135670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flang-rt/include/flang-rt/runtime/type-info.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,13 @@ class DerivedType {
}
RT_API_ATTRS const Descriptor &name() const { return name_.descriptor(); }
RT_API_ATTRS std::uint64_t sizeInBytes() const { return sizeInBytes_; }
RT_API_ATTRS const Descriptor &uninstatiated() const {
RT_API_ATTRS const Descriptor &uninstantiated() const {
return uninstantiated_.descriptor();
}
RT_API_ATTRS const DerivedType *uninstantiatedType() const {
return reinterpret_cast<const DerivedType *>(
uninstantiated().raw().base_addr);
}
RT_API_ATTRS const Descriptor &kindParameter() const {
return kindParameter_.descriptor();
}
Expand Down
50 changes: 14 additions & 36 deletions flang-rt/lib/runtime/derived-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,6 @@ bool RTDEF(ClassIs)(
return false;
}

static RT_API_ATTRS bool CompareDerivedTypeNames(
const Descriptor &a, const Descriptor &b) {
if (a.raw().version == CFI_VERSION &&
a.type() == TypeCode{TypeCategory::Character, 1} &&
a.ElementBytes() > 0 && a.rank() == 0 && a.OffsetElement() != nullptr &&
a.raw().version == CFI_VERSION &&
b.type() == TypeCode{TypeCategory::Character, 1} &&
b.ElementBytes() > 0 && b.rank() == 0 && b.OffsetElement() != nullptr &&
a.ElementBytes() == b.ElementBytes() &&
Fortran::runtime::memcmp(
a.OffsetElement(), b.OffsetElement(), a.ElementBytes()) == 0) {
return true;
}
return false;
}

inline RT_API_ATTRS bool CompareDerivedType(
const typeInfo::DerivedType *a, const typeInfo::DerivedType *b) {
return a == b || CompareDerivedTypeNames(a->name(), b->name());
}

static RT_API_ATTRS const typeInfo::DerivedType *GetDerivedType(
const Descriptor &desc) {
if (const DescriptorAddendum * addendum{desc.Addendum()}) {
Expand All @@ -121,22 +100,21 @@ bool RTDEF(SameTypeAs)(const Descriptor &a, const Descriptor &b) {
(bType != CFI_type_struct && bType != CFI_type_other)) {
// If either type is intrinsic, they must match.
return aType == bType;
} else {
const typeInfo::DerivedType *derivedTypeA{GetDerivedType(a)};
const typeInfo::DerivedType *derivedTypeB{GetDerivedType(b)};
if (derivedTypeA == nullptr || derivedTypeB == nullptr) {
// Unallocated/disassociated CLASS(*) never matches.
return false;
} else if (derivedTypeA == derivedTypeB) {
// Exact match of derived type.
return true;
} else {
// Otherwise compare with the name. Note 16.29 kind type parameters are
// not considered in the test.
return CompareDerivedTypeNames(
derivedTypeA->name(), derivedTypeB->name());
} else if (const typeInfo::DerivedType * derivedTypeA{GetDerivedType(a)}) {
if (const typeInfo::DerivedType * derivedTypeB{GetDerivedType(b)}) {
if (derivedTypeA == derivedTypeB) {
return true;
} else if (const typeInfo::DerivedType *
uninstDerivedTypeA{derivedTypeA->uninstantiatedType()}) {
// There are KIND type parameters, are these the same type if those
// are ignored?
const typeInfo::DerivedType *uninstDerivedTypeB{
derivedTypeB->uninstantiatedType()};
return uninstDerivedTypeA == uninstDerivedTypeB;
}
}
}
return false;
}

bool RTDEF(ExtendsTypeOf)(const Descriptor &a, const Descriptor &mold) {
Expand All @@ -155,7 +133,7 @@ bool RTDEF(ExtendsTypeOf)(const Descriptor &a, const Descriptor &mold) {
// dynamic type of MOLD.
for (const typeInfo::DerivedType *derivedTypeA{GetDerivedType(a)};
derivedTypeA; derivedTypeA = derivedTypeA->GetParentType()) {
if (CompareDerivedType(derivedTypeA, derivedTypeMold)) {
if (derivedTypeA == derivedTypeMold) {
return true;
}
}
Expand Down
23 changes: 14 additions & 9 deletions flang/lib/Semantics/runtime-type-info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class RuntimeTableBuilder {
void DescribeTypes(Scope &scope, bool inSchemata);

private:
const Symbol *DescribeType(Scope &);
const Symbol *DescribeType(Scope &, bool wantUninstantiatedPDT);
const Symbol &GetSchemaSymbol(const char *) const;
const DeclTypeSpec &GetSchema(const char *) const;
SomeExpr GetEnumValue(const char *) const;
Expand Down Expand Up @@ -238,7 +238,7 @@ void RuntimeTableBuilder::DescribeTypes(Scope &scope, bool inSchemata) {
inSchemata |= ignoreScopes_.find(&scope) != ignoreScopes_.end();
if (scope.IsDerivedType()) {
if (!inSchemata) { // don't loop trying to describe a schema
DescribeType(scope);
DescribeType(scope, /*wantUninstantiatedPDT=*/false);
}
} else {
scope.InstantiateDerivedTypes();
Expand Down Expand Up @@ -310,10 +310,10 @@ static SomeExpr StructureExpr(evaluate::StructureConstructor &&x) {
return SomeExpr{evaluate::Expr<evaluate::SomeDerived>{std::move(x)}};
}

static int GetIntegerKind(const Symbol &symbol) {
static int GetIntegerKind(const Symbol &symbol, bool canBeUninstantiated) {
auto dyType{evaluate::DynamicType::From(symbol)};
CHECK((dyType && dyType->category() == TypeCategory::Integer) ||
symbol.owner().context().HasError(symbol));
symbol.owner().context().HasError(symbol) || canBeUninstantiated);
return dyType && dyType->category() == TypeCategory::Integer
? dyType->kind()
: symbol.owner().context().GetDefaultKind(TypeCategory::Integer);
Expand Down Expand Up @@ -395,7 +395,8 @@ static std::optional<std::string> GetSuffixIfTypeKindParameters(
return std::nullopt;
}

const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
const Symbol *RuntimeTableBuilder::DescribeType(
Scope &dtScope, bool wantUninstantiatedPDT) {
if (const Symbol * info{dtScope.runtimeDerivedTypeDescription()}) {
return info;
}
Expand Down Expand Up @@ -449,7 +450,7 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
GetSuffixIfTypeKindParameters(*derivedTypeSpec, parameters)}) {
distinctName += *suffix;
}
} else if (isPDTDefinitionWithKindParameters) {
} else if (isPDTDefinitionWithKindParameters && !wantUninstantiatedPDT) {
return nullptr;
}
std::string dtDescName{(fir::kTypeDescriptorSeparator + distinctName).str()};
Expand Down Expand Up @@ -480,7 +481,8 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
}
if (const Symbol *
uninstDescObject{isPDTInstantiation
? DescribeType(DEREF(const_cast<Scope *>(dtSymbol->scope())))
? DescribeType(DEREF(const_cast<Scope *>(dtSymbol->scope())),
/*wantUninstantiatedPDT=*/true)
: nullptr}) {
AddValue(dtValues, derivedTypeSchema_, "uninstantiated"s,
evaluate::AsGenericExpr(evaluate::Expr<evaluate::SomeDerived>{
Expand Down Expand Up @@ -516,7 +518,8 @@ const Symbol *RuntimeTableBuilder::DescribeType(Scope &dtScope) {
}
kinds.emplace_back(value);
} else { // LEN= parameter
lenKinds.emplace_back(GetIntegerKind(*inst));
lenKinds.emplace_back(
GetIntegerKind(*inst, isPDTDefinitionWithKindParameters));
}
}
}
Expand Down Expand Up @@ -804,7 +807,9 @@ evaluate::StructureConstructor RuntimeTableBuilder::DescribeComponent(
const DerivedTypeSpec &spec{dyType.GetDerivedTypeSpec()};
Scope *derivedScope{const_cast<Scope *>(
spec.scope() ? spec.scope() : spec.typeSymbol().scope())};
if (const Symbol * derivedDescription{DescribeType(DEREF(derivedScope))}) {
if (const Symbol *
derivedDescription{DescribeType(
DEREF(derivedScope), /*wantUninstantiatedPDT=*/false)}) {
AddValue(values, componentSchema_, "derived"s,
evaluate::AsGenericExpr(evaluate::Expr<evaluate::SomeDerived>{
evaluate::Designator<evaluate::SomeDerived>{
Expand Down
3 changes: 2 additions & 1 deletion flang/test/Semantics/typeinfo01.f90
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ module m03
end type
type(kpdt(4)) :: x
!CHECK: .c.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(component) shape: 0_8:0_8 init:[component::component(name=.n.a,genre=1_1,category=2_1,kind=4_1,rank=0_1,offset=0_8,characterlen=value(genre=1_1,value=0_8),derived=NULL(),lenvalue=NULL(),bounds=NULL(),initialization=NULL())]
!CHECK: .dt.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(binding=NULL(),name=.n.kpdt,sizeinbytes=4_8,uninstantiated=NULL(),kindparameter=.kp.kpdt.4,lenparameterkind=NULL(),component=.c.kpdt.4,procptr=NULL(),special=NULL(),specialbitset=0_4,hasparent=0_1,noinitializationneeded=1_1,nodestructionneeded=1_1,nofinalizationneeded=1_1)
!CHECK: .dt.kpdt, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(name=.n.kpdt,uninstantiated=NULL(),kindparameter=.kp.kpdt,lenparameterkind=NULL())
!CHECK: .dt.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: TYPE(derivedtype) init:derivedtype(binding=NULL(),name=.n.kpdt,sizeinbytes=4_8,uninstantiated=.dt.kpdt,kindparameter=.kp.kpdt.4,lenparameterkind=NULL(),component=.c.kpdt.4,procptr=NULL(),special=NULL(),specialbitset=0_4,hasparent=0_1,noinitializationneeded=1_1,nodestructionneeded=1_1,nofinalizationneeded=1_1)
!CHECK: .kp.kpdt.4, SAVE, TARGET (CompilerCreated, ReadOnly): ObjectEntity type: INTEGER(8) shape: 0_8:0_8 init:[INTEGER(8)::4_8]
end module

Expand Down
Loading