Skip to content

Commit 6ed9d3a

Browse files
clementvaljeanPeriermleair
committed
[flang] Lower count intrinsic
This patch adds lowering for the count intrinsic. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D121782 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: mleair <[email protected]>
1 parent 8dca38d commit 6ed9d3a

File tree

2 files changed

+146
-36
lines changed

2 files changed

+146
-36
lines changed

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 101 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@
3838
#define PGMATH_DECLARE
3939
#include "flang/Evaluate/pgmath.h.inc"
4040

41+
/// This file implements lowering of Fortran intrinsic procedures.
42+
/// Intrinsics are lowered to a mix of FIR and MLIR operations as
43+
/// well as call to runtime functions or LLVM intrinsics.
44+
45+
/// Lowering of intrinsic procedure calls is based on a map that associates
46+
/// Fortran intrinsic generic names to FIR generator functions.
47+
/// All generator functions are member functions of the IntrinsicLibrary class
48+
/// and have the same interface.
49+
/// If no generator is given for an intrinsic name, a math runtime library
50+
/// is searched for an implementation and, if a runtime function is found,
51+
/// a call is generated for it. LLVM intrinsics are handled as a math
52+
/// runtime library here.
53+
4154
/// Enums used to templatize and share lowering of MIN and MAX.
4255
enum class Extremum { Min, Max };
4356

@@ -81,19 +94,6 @@ enum class ExtremumBehavior {
8194
// possible to implement it without some target dependent runtime.
8295
};
8396

84-
/// This file implements lowering of Fortran intrinsic procedures.
85-
/// Intrinsics are lowered to a mix of FIR and MLIR operations as
86-
/// well as call to runtime functions or LLVM intrinsics.
87-
88-
/// Lowering of intrinsic procedure calls is based on a map that associates
89-
/// Fortran intrinsic generic names to FIR generator functions.
90-
/// All generator functions are member functions of the IntrinsicLibrary class
91-
/// and have the same interface.
92-
/// If no generator is given for an intrinsic name, a math runtime library
93-
/// is searched for an implementation and, if a runtime function is found,
94-
/// a call is generated for it. LLVM intrinsics are handled as a math
95-
/// runtime library here.
96-
9797
fir::ExtendedValue Fortran::lower::getAbsentIntrinsicArgument() {
9898
return fir::UnboxedValue{};
9999
}
@@ -439,6 +439,7 @@ struct IntrinsicLibrary {
439439
fir::ExtendedValue genAssociated(mlir::Type,
440440
llvm::ArrayRef<fir::ExtendedValue>);
441441
fir::ExtendedValue genChar(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
442+
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
442443
mlir::Value genDim(mlir::Type, llvm::ArrayRef<mlir::Value>);
443444
fir::ExtendedValue genDotProduct(mlir::Type,
444445
llvm::ArrayRef<fir::ExtendedValue>);
@@ -592,6 +593,10 @@ static constexpr IntrinsicHandler handlers[]{
592593
{{{"pointer", asInquired}, {"target", asInquired}}},
593594
/*isElemental=*/false},
594595
{"char", &I::genChar},
596+
{"count",
597+
&I::genCount,
598+
{{{"mask", asAddr}, {"dim", asValue}, {"kind", asValue}}},
599+
/*isElemental=*/false},
595600
{"cpu_time",
596601
&I::genCpuTime,
597602
{{{"time", asAddr}}},
@@ -1644,31 +1649,64 @@ IntrinsicLibrary::genChar(mlir::Type type,
16441649
return fir::CharBoxValue{cast, len};
16451650
}
16461651

1647-
// DIM
1648-
mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType,
1649-
llvm::ArrayRef<mlir::Value> args) {
1650-
assert(args.size() == 2);
1651-
if (resultType.isa<mlir::IntegerType>()) {
1652-
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
1653-
auto diff = builder.create<mlir::arith::SubIOp>(loc, args[0], args[1]);
1654-
auto cmp = builder.create<mlir::arith::CmpIOp>(
1655-
loc, mlir::arith::CmpIPredicate::sgt, diff, zero);
1656-
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
1652+
// COUNT
1653+
fir::ExtendedValue
1654+
IntrinsicLibrary::genCount(mlir::Type resultType,
1655+
llvm::ArrayRef<fir::ExtendedValue> args) {
1656+
assert(args.size() == 3);
1657+
1658+
// Handle mask argument
1659+
fir::BoxValue mask = builder.createBox(loc, args[0]);
1660+
unsigned maskRank = mask.rank();
1661+
1662+
assert(maskRank > 0);
1663+
1664+
// Handle optional dim argument
1665+
bool absentDim = isAbsent(args[1]);
1666+
mlir::Value dim =
1667+
absentDim ? builder.createIntegerConstant(loc, builder.getIndexType(), 0)
1668+
: fir::getBase(args[1]);
1669+
1670+
if (absentDim || maskRank == 1) {
1671+
// Result is scalar if no dim argument or mask is rank 1.
1672+
// So, call specialized Count runtime routine.
1673+
return builder.createConvert(
1674+
loc, resultType,
1675+
fir::runtime::genCount(builder, loc, fir::getBase(mask), dim));
16571676
}
1658-
assert(fir::isa_real(resultType) && "Only expects real and integer in DIM");
1659-
mlir::Value zero = builder.createRealZeroConstant(loc, resultType);
1660-
auto diff = builder.create<mlir::arith::SubFOp>(loc, args[0], args[1]);
1661-
auto cmp = builder.create<mlir::arith::CmpFOp>(
1662-
loc, mlir::arith::CmpFPredicate::OGT, diff, zero);
1663-
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
1664-
}
16651677

1666-
// DOT_PRODUCT
1667-
fir::ExtendedValue
1668-
IntrinsicLibrary::genDotProduct(mlir::Type resultType,
1669-
llvm::ArrayRef<fir::ExtendedValue> args) {
1670-
return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc,
1671-
stmtCtx, args);
1678+
// Call general CountDim runtime routine.
1679+
1680+
// Handle optional kind argument
1681+
bool absentKind = isAbsent(args[2]);
1682+
mlir::Value kind = absentKind ? builder.createIntegerConstant(
1683+
loc, builder.getIndexType(),
1684+
builder.getKindMap().defaultIntegerKind())
1685+
: fir::getBase(args[2]);
1686+
1687+
// Create mutable fir.box to be passed to the runtime for the result.
1688+
mlir::Type type = builder.getVarLenSeqTy(resultType, maskRank - 1);
1689+
fir::MutableBoxValue resultMutableBox =
1690+
fir::factory::createTempMutableBox(builder, loc, type);
1691+
1692+
mlir::Value resultIrBox =
1693+
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
1694+
1695+
fir::runtime::genCountDim(builder, loc, resultIrBox, fir::getBase(mask), dim,
1696+
kind);
1697+
1698+
// Handle cleanup of allocatable result descriptor and return
1699+
fir::ExtendedValue res =
1700+
fir::factory::genMutableBoxRead(builder, loc, resultMutableBox);
1701+
return res.match(
1702+
[&](const fir::ArrayBoxValue &box) -> fir::ExtendedValue {
1703+
// Add cleanup code
1704+
addCleanUpForTemp(loc, box.getAddr());
1705+
return box;
1706+
},
1707+
[&](const auto &) -> fir::ExtendedValue {
1708+
fir::emitFatalError(loc, "unexpected result for COUNT");
1709+
});
16721710
}
16731711

16741712
// CPU_TIME
@@ -1699,6 +1737,33 @@ void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
16991737
charArgs[2], values);
17001738
}
17011739

1740+
// DIM
1741+
mlir::Value IntrinsicLibrary::genDim(mlir::Type resultType,
1742+
llvm::ArrayRef<mlir::Value> args) {
1743+
assert(args.size() == 2);
1744+
if (resultType.isa<mlir::IntegerType>()) {
1745+
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
1746+
auto diff = builder.create<mlir::arith::SubIOp>(loc, args[0], args[1]);
1747+
auto cmp = builder.create<mlir::arith::CmpIOp>(
1748+
loc, mlir::arith::CmpIPredicate::sgt, diff, zero);
1749+
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
1750+
}
1751+
assert(fir::isa_real(resultType) && "Only expects real and integer in DIM");
1752+
mlir::Value zero = builder.createRealZeroConstant(loc, resultType);
1753+
auto diff = builder.create<mlir::arith::SubFOp>(loc, args[0], args[1]);
1754+
auto cmp = builder.create<mlir::arith::CmpFOp>(
1755+
loc, mlir::arith::CmpFPredicate::OGT, diff, zero);
1756+
return builder.create<mlir::arith::SelectOp>(loc, cmp, diff, zero);
1757+
}
1758+
1759+
// DOT_PRODUCT
1760+
fir::ExtendedValue
1761+
IntrinsicLibrary::genDotProduct(mlir::Type resultType,
1762+
llvm::ArrayRef<fir::ExtendedValue> args) {
1763+
return genDotProd(fir::runtime::genDotProduct, resultType, builder, loc,
1764+
stmtCtx, args);
1765+
}
1766+
17021767
// IAND
17031768
mlir::Value IntrinsicLibrary::genIand(mlir::Type resultType,
17041769
llvm::ArrayRef<mlir::Value> args) {

flang/test/Lower/Intrinsics/count.f90

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
! RUN: bbc -emit-fir %s -o - | FileCheck %s
2+
3+
! CHECK-LABEL: count_test1
4+
! CHECK-SAME: %[[arg0:.*]]: !fir.ref<i32>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
5+
subroutine count_test1(rslt, mask)
6+
integer :: rslt
7+
logical :: mask(:)
8+
! CHECK-DAG: %[[c1:.*]] = arith.constant 0 : index
9+
! CHECK-DAG: %[[a2:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
10+
! CHECK: %[[a4:.*]] = fir.convert %[[c1]] : (index) -> i32
11+
rslt = count(mask)
12+
! CHECK: %[[a5:.*]] = fir.call @_FortranACount(%[[a2]], %{{.*}}, %{{.*}}, %[[a4]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
13+
end subroutine
14+
15+
! CHECK-LABEL: test_count2
16+
! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?xi32>>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>{{.*}})
17+
subroutine test_count2(rslt, mask)
18+
integer :: rslt(:)
19+
logical :: mask(:,:)
20+
! CHECK-DAG: %[[c1_i32:.*]] = arith.constant 1 : i32
21+
! CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
22+
! CHECK-DAG: %[[a0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
23+
! CHECK: %[[a5:.*]] = fir.convert %[[a0]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
24+
! CHECK: %[[a6:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.box<none>
25+
! CHECK: %[[a7:.*]] = fir.convert %[[c4]] : (index) -> i32
26+
rslt = count(mask, dim=1)
27+
! CHECK: %{{.*}} = fir.call @_FortranACountDim(%[[a5]], %[[a6]], %[[c1_i32]], %[[a7]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
28+
! CHECK: %[[a10:.*]] = fir.load %[[a0]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
29+
! CHECK: %[[a12:.*]] = fir.box_addr %[[a10]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
30+
! CHECK: fir.freemem %[[a12]]
31+
end subroutine
32+
33+
! CHECK-LABEL: test_count3
34+
! CHECK-SAME: %[[arg0:.*]]: !fir.ref<i32>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>{{.*}})
35+
subroutine test_count3(rslt, mask)
36+
integer :: rslt
37+
logical :: mask(:)
38+
! CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
39+
! CHECK-DAG: %[[a1:.*]] = fir.convert %[[arg1]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
40+
! CHECK: %[[a3:.*]] = fir.convert %[[c0]] : (index) -> i32
41+
call bar(count(mask, kind=2))
42+
! CHECK: %[[a4:.*]] = fir.call @_FortranACount(%[[a1]], %{{.*}}, %{{.*}}, %[[a3]]) : (!fir.box<none>, !fir.ref<i8>, i32, i32) -> i64
43+
! CHECK: %{{.*}} = fir.convert %[[a4]] : (i64) -> i16
44+
end subroutine
45+

0 commit comments

Comments
 (0)