Skip to content

Commit 1ca3927

Browse files
authored
[flang] Added definition of hlfir.cshift operation. (#118732)
CSHIFT intrinsic will be lowered to this operation, which then can be optimized as inline sequence or lowered into a runtime call.
1 parent 10f315d commit 1ca3927

File tree

6 files changed

+216
-0
lines changed

6 files changed

+216
-0
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
136136
/// This has to be cleaned up, when HLFIR is the default.
137137
bool mayHaveAllocatableComponent(mlir::Type ty);
138138

139+
/// Scalar integer or a sequence of integers (via boxed array or expr).
140+
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
141+
139142
} // namespace hlfir
140143

141144
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H

flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ def IsPolymorphicObjectPred
155155
def AnyPolymorphicObject : Type<IsPolymorphicObjectPred,
156156
"any polymorphic object">;
157157

158+
def IsFortranIntegerScalarOrArrayPred
159+
: CPred<"::hlfir::isFortranIntegerScalarOrArrayObject($_self)">;
160+
def AnyFortranIntegerScalarOrArrayObject
161+
: Type<IsFortranIntegerScalarOrArrayPred,
162+
"A scalar or array object containing integers">;
163+
158164
def hlfir_CharExtremumPredicateAttr : I32EnumAttr<
159165
"CharExtremumPredicate", "",
160166
[

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,27 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
699699
let hasVerifier = 1;
700700
}
701701

702+
def hlfir_CShiftOp
703+
: hlfir_Op<
704+
"cshift", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
705+
let summary = "CSHIFT transformational intrinsic";
706+
let description = [{
707+
Circular shift of an array
708+
}];
709+
710+
let arguments = (ins AnyFortranArrayObject:$array,
711+
AnyFortranIntegerScalarOrArrayObject:$shift,
712+
Optional<AnyIntegerType>:$dim);
713+
714+
let results = (outs hlfir_ExprType);
715+
716+
let assemblyFormat = [{
717+
$array $shift (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
718+
}];
719+
720+
let hasVerifier = 1;
721+
}
722+
702723
// An allocation effect is needed because the value produced by the associate
703724
// is "deallocated" by hlfir.end_associate (the end_associate must not be
704725
// removed, and there must be only one hlfir.end_associate).

flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,12 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) {
228228
return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
229229
isPolymorphic);
230230
}
231+
232+
bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
233+
if (isBoxAddressType(type))
234+
return false;
235+
236+
mlir::Type unwrappedType = fir::unwrapPassByRefType(fir::unwrapRefType(type));
237+
mlir::Type elementType = getFortranElementType(unwrappedType);
238+
return mlir::isa<mlir::IntegerType>(elementType);
239+
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,108 @@ void hlfir::MatmulTransposeOp::getEffects(
13411341
getIntrinsicEffects(getOperation(), effects);
13421342
}
13431343

1344+
//===----------------------------------------------------------------------===//
1345+
// CShiftOp
1346+
//===----------------------------------------------------------------------===//
1347+
1348+
llvm::LogicalResult hlfir::CShiftOp::verify() {
1349+
mlir::Value array = getArray();
1350+
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
1351+
hlfir::getFortranElementOrSequenceType(array.getType()));
1352+
llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
1353+
std::size_t arrayRank = inShape.size();
1354+
mlir::Type eleTy = arrayTy.getEleTy();
1355+
hlfir::ExprType resultTy = mlir::cast<hlfir::ExprType>(getResult().getType());
1356+
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
1357+
std::size_t resultRank = resultShape.size();
1358+
mlir::Type resultEleTy = resultTy.getEleTy();
1359+
mlir::Value shift = getShift();
1360+
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());
1361+
1362+
if (eleTy != resultEleTy) {
1363+
if (mlir::isa<fir::CharacterType>(eleTy) &&
1364+
mlir::isa<fir::CharacterType>(resultEleTy)) {
1365+
auto eleCharTy = mlir::cast<fir::CharacterType>(eleTy);
1366+
auto resultCharTy = mlir::cast<fir::CharacterType>(resultEleTy);
1367+
if (eleCharTy.getFKind() != resultCharTy.getFKind())
1368+
return emitOpError("kind mismatch between input and output arrays");
1369+
if (eleCharTy.getLen() != fir::CharacterType::unknownLen() &&
1370+
resultCharTy.getLen() != fir::CharacterType::unknownLen() &&
1371+
eleCharTy.getLen() != resultCharTy.getLen())
1372+
return emitOpError(
1373+
"character LEN mismatch between input and output arrays");
1374+
} else {
1375+
return emitOpError(
1376+
"input and output arrays should have the same element type");
1377+
}
1378+
}
1379+
1380+
if (arrayRank != resultRank)
1381+
return emitOpError("input and output arrays should have the same rank");
1382+
1383+
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
1384+
for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape))
1385+
if (inDim != unknownExtent && resultDim != unknownExtent &&
1386+
inDim != resultDim)
1387+
return emitOpError(
1388+
"output array's shape conflicts with the input array's shape");
1389+
1390+
int64_t dimVal = -1;
1391+
if (!getDim())
1392+
dimVal = 1;
1393+
else if (auto dim = fir::getIntIfConstant(getDim()))
1394+
dimVal = *dim;
1395+
1396+
// The DIM argument may be statically invalid (e.g. exceed the
1397+
// input array rank) in dead code after constant propagation,
1398+
// so avoid some checks unless useStrictIntrinsicVerifier is true.
1399+
if (useStrictIntrinsicVerifier && dimVal != -1) {
1400+
if (dimVal < 1)
1401+
return emitOpError("DIM must be >= 1");
1402+
if (dimVal > static_cast<int64_t>(arrayRank))
1403+
return emitOpError("DIM must be <= input array's rank");
1404+
}
1405+
1406+
if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) {
1407+
// SHIFT is an array. Verify the rank and the shape (if DIM is constant).
1408+
llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape();
1409+
std::size_t shiftRank = shiftShape.size();
1410+
if (shiftRank != arrayRank - 1)
1411+
return emitOpError(
1412+
"SHIFT's rank must be 1 less than the input array's rank");
1413+
1414+
if (useStrictIntrinsicVerifier && dimVal != -1) {
1415+
// SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
1416+
// where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
1417+
int64_t arrayDimIdx = 0;
1418+
int64_t shiftDimIdx = 0;
1419+
for (auto shiftDim : shiftShape) {
1420+
if (arrayDimIdx == dimVal - 1)
1421+
++arrayDimIdx;
1422+
1423+
if (inShape[arrayDimIdx] != unknownExtent &&
1424+
shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim)
1425+
return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) +
1426+
") must be equal to SHAPE(SHIFT)(" +
1427+
llvm::Twine(shiftDimIdx + 1) +
1428+
"): " + llvm::Twine(inShape[arrayDimIdx]) +
1429+
" != " + llvm::Twine(shiftDim));
1430+
++arrayDimIdx;
1431+
++shiftDimIdx;
1432+
}
1433+
}
1434+
}
1435+
1436+
return mlir::success();
1437+
}
1438+
1439+
void hlfir::CShiftOp::getEffects(
1440+
llvm::SmallVectorImpl<
1441+
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
1442+
&effects) {
1443+
getIntrinsicEffects(getOperation(), effects);
1444+
}
1445+
13441446
//===----------------------------------------------------------------------===//
13451447
// AssociateOp
13461448
//===----------------------------------------------------------------------===//

flang/test/HLFIR/invalid.fir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,78 @@ func.func @bad_eval_in_mem_3() {
13481348
}
13491349
return
13501350
}
1351+
1352+
// -----
1353+
1354+
func.func @bad_cshift1(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
1355+
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same element type}}
1356+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?x?xf32>
1357+
return
1358+
}
1359+
1360+
// -----
1361+
1362+
func.func @bad_cshift2(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
1363+
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same rank}}
1364+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?xi32>
1365+
return
1366+
}
1367+
1368+
// -----
1369+
1370+
func.func @bad_cshift3(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
1371+
// expected-error@+1 {{'hlfir.cshift' op output array's shape conflicts with the input array's shape}}
1372+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, i32) -> !hlfir.expr<2x3xi32>
1373+
return
1374+
}
1375+
1376+
// -----
1377+
1378+
func.func @bad_cshift4(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
1379+
%c0 = arith.constant 0 : index
1380+
// expected-error@+1 {{'hlfir.cshift' op DIM must be >= 1}}
1381+
%0 = hlfir.cshift %arg0 %arg1 dim %c0 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
1382+
return
1383+
}
1384+
1385+
// -----
1386+
1387+
func.func @bad_cshift5(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
1388+
%c10 = arith.constant 10 : index
1389+
// expected-error@+1 {{'hlfir.cshift' op DIM must be <= input array's rank}}
1390+
%0 = hlfir.cshift %arg0 %arg1 dim %c10 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
1391+
return
1392+
}
1393+
1394+
// -----
1395+
1396+
func.func @bad_cshift6(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
1397+
// expected-error@+1 {{'hlfir.cshift' op SHIFT's rank must be 1 less than the input array's rank}}
1398+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
1399+
return
1400+
}
1401+
1402+
// -----
1403+
1404+
func.func @bad_cshift7(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<3xi32>) {
1405+
%c1 = arith.constant 1 : index
1406+
// expected-error@+1 {{'hlfir.cshift' op SHAPE(ARRAY)(2) must be equal to SHAPE(SHIFT)(1): 2 != 3}}
1407+
%0 = hlfir.cshift %arg0 %arg1 dim %c1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<3xi32>, index) -> !hlfir.expr<2x2xi32>
1408+
return
1409+
}
1410+
1411+
// -----
1412+
1413+
func.func @bad_cshift8(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32) {
1414+
// expected-error@+1 {{'hlfir.cshift' op kind mismatch between input and output arrays}}
1415+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,?>>, i32) -> !hlfir.expr<?x!fir.char<2,?>>
1416+
return
1417+
}
1418+
1419+
// -----
1420+
1421+
func.func @bad_cshift9(%arg0: !hlfir.expr<?x!fir.char<1,1>>, %arg1: i32) {
1422+
// expected-error@+1 {{'hlfir.cshift' op character LEN mismatch between input and output arrays}}
1423+
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,1>>, i32) -> !hlfir.expr<?x!fir.char<1,2>>
1424+
return
1425+
}

0 commit comments

Comments
 (0)