Skip to content

Commit 457c4fe

Browse files
committed
[flang] Allow scalar boxed record type in intrinsic elemental lowering
Relax a bit the condition added in D144417 and allow scalar polymorphic entities and boxed scalar record type. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D145058
1 parent 8d09bd6 commit 457c4fe

File tree

5 files changed

+77
-1
lines changed

5 files changed

+77
-1
lines changed

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ bool isBoxNone(mlir::Type ty);
283283
/// e.g. !fir.box<!fir.type<derived>>
284284
bool isBoxedRecordType(mlir::Type ty);
285285

286+
/// Return true iff `ty` is a scalar boxed record type.
287+
/// e.g. !fir.box<!fir.type<derived>>
288+
/// !fir.box<!fir.heap<!fir.type<derived>>>
289+
/// !fir.class<!fir.type<derived>>
290+
bool isScalarBoxedRecordType(mlir::Type ty);
291+
286292
/// Return the nested RecordType if one if found. Return ty otherwise.
287293
mlir::Type getDerivedType(mlir::Type ty);
288294

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1708,7 +1708,7 @@ IntrinsicLibrary::genElementalCall<IntrinsicLibrary::ExtendedGenerator>(
17081708
for (const fir::ExtendedValue &arg : args) {
17091709
auto *box = arg.getBoxOf<fir::BoxValue>();
17101710
if (!arg.getUnboxed() && !arg.getCharBox() &&
1711-
!(box && fir::isPolymorphicType(fir::getBase(*box).getType())))
1711+
!(box && fir::isScalarBoxedRecordType(fir::getBase(*box).getType())))
17121712
fir::emitFatalError(loc, "nonscalar intrinsic argument");
17131713
}
17141714
if (outline)

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,20 @@ bool isBoxedRecordType(mlir::Type ty) {
290290
return false;
291291
}
292292

293+
bool isScalarBoxedRecordType(mlir::Type ty) {
294+
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
295+
ty = refTy;
296+
if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
297+
if (boxTy.getEleTy().isa<fir::RecordType>())
298+
return true;
299+
if (auto heapTy = boxTy.getEleTy().dyn_cast<fir::HeapType>())
300+
return heapTy.getEleTy().isa<fir::RecordType>();
301+
if (auto ptrTy = boxTy.getEleTy().dyn_cast<fir::PointerType>())
302+
return ptrTy.getEleTy().isa<fir::RecordType>();
303+
}
304+
return false;
305+
}
306+
293307
static bool isAssumedType(mlir::Type ty) {
294308
if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
295309
if (boxTy.getEleTy().isa<mlir::NoneType>())

flang/test/Lower/polymorphic-temp.f90

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,23 @@ subroutine test_merge_intrinsic(a, b)
207207
! CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[ARG0]], %[[ARG1]] : !fir.class<!fir.type<_QMpoly_tmpTp1{a:i32}>>
208208
! CHECK: fir.call @_QMpoly_tmpPcheck_scalar(%[[SELECT]]) {{.*}} : (!fir.class<!fir.type<_QMpoly_tmpTp1{a:i32}>>) -> ()
209209

210+
subroutine test_merge_intrinsic2(a, b, i)
211+
class(p1), allocatable, intent(in) :: a
212+
type(p1), allocatable :: b
213+
integer, intent(in) :: i
214+
215+
call check_scalar(merge(a, b, i==1))
216+
end subroutine
217+
218+
219+
! CHECK-LABEL: func.func @_QMpoly_tmpPtest_merge_intrinsic2(
220+
! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>> {fir.bindc_name = "b"}, %[[I:.*]]: !fir.ref<i32> {fir.bindc_name = "i"}) {
221+
! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]] : !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>>
222+
! CHECK: %[[LOAD_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>>
223+
! CHECK: %[[LOAD_I:.*]] = fir.load %[[I]] : !fir.ref<i32>
224+
! CHECK: %[[C1:.*]] = arith.constant 1 : i32
225+
! CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[LOAD_I]], %[[C1]] : i32
226+
! CHECK: %[[B_CONV:.*]] = fir.convert %[[LOAD_B]] : (!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>) -> !fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>
227+
! CHECK: %{{.*}} = arith.select %[[CMPI]], %[[LOAD_A]], %[[B_CONV]] : !fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>
228+
210229
end module

flang/unittests/Optimizer/FIRTypesTest.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,43 @@ TEST_F(FIRTypesTest, isBoxedRecordType) {
147147
fir::ReferenceType::get(mlir::IntegerType::get(&context, 32)))));
148148
}
149149

150+
// Test fir::isScalarBoxedRecordType from flang/Optimizer/Dialect/FIRType.h.
151+
TEST_F(FIRTypesTest, isScalarBoxedRecordType) {
152+
mlir::Type recTy = fir::RecordType::get(&context, "dt");
153+
mlir::Type seqRecTy =
154+
fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, recTy);
155+
mlir::Type ty = fir::BoxType::get(recTy);
156+
EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
157+
EXPECT_TRUE(fir::isScalarBoxedRecordType(fir::ReferenceType::get(ty)));
158+
159+
// CLASS(T), ALLOCATABLE
160+
ty = fir::ClassType::get(fir::HeapType::get(recTy));
161+
EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
162+
163+
// TYPE(T), ALLOCATABLE
164+
ty = fir::BoxType::get(fir::HeapType::get(recTy));
165+
EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
166+
167+
// TYPE(T), POINTER
168+
ty = fir::BoxType::get(fir::PointerType::get(recTy));
169+
EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
170+
171+
// CLASS(T), POINTER
172+
ty = fir::ClassType::get(fir::PointerType::get(recTy));
173+
EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
174+
175+
// TYPE(T), DIMENSION(10)
176+
ty = fir::BoxType::get(fir::SequenceType::get({10}, recTy));
177+
EXPECT_FALSE(fir::isScalarBoxedRecordType(ty));
178+
179+
// TYPE(T), DIMENSION(:)
180+
ty = fir::BoxType::get(seqRecTy);
181+
EXPECT_FALSE(fir::isScalarBoxedRecordType(ty));
182+
183+
EXPECT_FALSE(fir::isScalarBoxedRecordType(fir::BoxType::get(
184+
fir::ReferenceType::get(mlir::IntegerType::get(&context, 32)))));
185+
}
186+
150187
TEST_F(FIRTypesTest, updateTypeForUnlimitedPolymorphic) {
151188
// RecordType are not changed.
152189

0 commit comments

Comments
 (0)