Skip to content

[flang] Added definition of hlfir.cshift operation. #118732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
/// This has to be cleaned up, when HLFIR is the default.
bool mayHaveAllocatableComponent(mlir::Type ty);

/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ def IsPolymorphicObjectPred
def AnyPolymorphicObject : Type<IsPolymorphicObjectPred,
"any polymorphic object">;

def IsFortranIntegerScalarOrArrayPred
: CPred<"::hlfir::isFortranIntegerScalarOrArrayObject($_self)">;
def AnyFortranIntegerScalarOrArrayObject
: Type<IsFortranIntegerScalarOrArrayPred,
"A scalar or array object containing integers">;

def hlfir_CharExtremumPredicateAttr : I32EnumAttr<
"CharExtremumPredicate", "",
[
Expand Down
21 changes: 21 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,27 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
let hasVerifier = 1;
}

def hlfir_CShiftOp
: hlfir_Op<
"cshift", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "CSHIFT transformational intrinsic";
let description = [{
Circular shift of an array
}];

let arguments = (ins AnyFortranArrayObject:$array,
AnyFortranIntegerScalarOrArrayObject:$shift,
Optional<AnyIntegerType>:$dim);

let results = (outs hlfir_ExprType);

let assemblyFormat = [{
$array $shift (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

// An allocation effect is needed because the value produced by the associate
// is "deallocated" by hlfir.end_associate (the end_associate must not be
// removed, and there must be only one hlfir.end_associate).
Expand Down
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,12 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) {
return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
isPolymorphic);
}

bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
if (isBoxAddressType(type))
return false;

mlir::Type unwrappedType = fir::unwrapPassByRefType(fir::unwrapRefType(type));
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}
102 changes: 102 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,108 @@ void hlfir::MatmulTransposeOp::getEffects(
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// CShiftOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult hlfir::CShiftOp::verify() {
mlir::Value array = getArray();
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(array.getType()));
llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
std::size_t arrayRank = inShape.size();
mlir::Type eleTy = arrayTy.getEleTy();
hlfir::ExprType resultTy = mlir::cast<hlfir::ExprType>(getResult().getType());
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
std::size_t resultRank = resultShape.size();
mlir::Type resultEleTy = resultTy.getEleTy();
mlir::Value shift = getShift();
mlir::Type shiftTy = hlfir::getFortranElementOrSequenceType(shift.getType());

if (eleTy != resultEleTy) {
if (mlir::isa<fir::CharacterType>(eleTy) &&
mlir::isa<fir::CharacterType>(resultEleTy)) {
auto eleCharTy = mlir::cast<fir::CharacterType>(eleTy);
auto resultCharTy = mlir::cast<fir::CharacterType>(resultEleTy);
if (eleCharTy.getFKind() != resultCharTy.getFKind())
return emitOpError("kind mismatch between input and output arrays");
if (eleCharTy.getLen() != fir::CharacterType::unknownLen() &&
resultCharTy.getLen() != fir::CharacterType::unknownLen() &&
eleCharTy.getLen() != resultCharTy.getLen())
return emitOpError(
"character LEN mismatch between input and output arrays");
} else {
return emitOpError(
"input and output arrays should have the same element type");
}
}

if (arrayRank != resultRank)
return emitOpError("input and output arrays should have the same rank");

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape))
if (inDim != unknownExtent && resultDim != unknownExtent &&
inDim != resultDim)
return emitOpError(
"output array's shape conflicts with the input array's shape");

int64_t dimVal = -1;
if (!getDim())
dimVal = 1;
else if (auto dim = fir::getIntIfConstant(getDim()))
dimVal = *dim;

// The DIM argument may be statically invalid (e.g. exceed the
// input array rank) in dead code after constant propagation,
// so avoid some checks unless useStrictIntrinsicVerifier is true.
if (useStrictIntrinsicVerifier && dimVal != -1) {
if (dimVal < 1)
return emitOpError("DIM must be >= 1");
if (dimVal > static_cast<int64_t>(arrayRank))
return emitOpError("DIM must be <= input array's rank");
}

if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) {
// SHIFT is an array. Verify the rank and the shape (if DIM is constant).
llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape();
std::size_t shiftRank = shiftShape.size();
if (shiftRank != arrayRank - 1)
return emitOpError(
"SHIFT's rank must be 1 less than the input array's rank");

if (useStrictIntrinsicVerifier && dimVal != -1) {
// SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
// where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
int64_t arrayDimIdx = 0;
int64_t shiftDimIdx = 0;
for (auto shiftDim : shiftShape) {
if (arrayDimIdx == dimVal - 1)
++arrayDimIdx;

if (inShape[arrayDimIdx] != unknownExtent &&
shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim)
return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) +
") must be equal to SHAPE(SHIFT)(" +
llvm::Twine(shiftDimIdx + 1) +
"): " + llvm::Twine(inShape[arrayDimIdx]) +
" != " + llvm::Twine(shiftDim));
++arrayDimIdx;
++shiftDimIdx;
}
}
}

return mlir::success();
}

void hlfir::CShiftOp::getEffects(
llvm::SmallVectorImpl<
mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>>
&effects) {
getIntrinsicEffects(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// AssociateOp
//===----------------------------------------------------------------------===//
Expand Down
75 changes: 75 additions & 0 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1348,3 +1348,78 @@ func.func @bad_eval_in_mem_3() {
}
return
}

// -----

func.func @bad_cshift1(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same element type}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?x?xf32>
return
}

// -----

func.func @bad_cshift2(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same rank}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?xi32>
return
}

// -----

func.func @bad_cshift3(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op output array's shape conflicts with the input array's shape}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, i32) -> !hlfir.expr<2x3xi32>
return
}

// -----

func.func @bad_cshift4(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'hlfir.cshift' op DIM must be >= 1}}
%0 = hlfir.cshift %arg0 %arg1 dim %c0 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift5(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
%c10 = arith.constant 10 : index
// expected-error@+1 {{'hlfir.cshift' op DIM must be <= input array's rank}}
%0 = hlfir.cshift %arg0 %arg1 dim %c10 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift6(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
// expected-error@+1 {{'hlfir.cshift' op SHIFT's rank must be 1 less than the input array's rank}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift7(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<3xi32>) {
%c1 = arith.constant 1 : index
// expected-error@+1 {{'hlfir.cshift' op SHAPE(ARRAY)(2) must be equal to SHAPE(SHIFT)(1): 2 != 3}}
%0 = hlfir.cshift %arg0 %arg1 dim %c1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<3xi32>, index) -> !hlfir.expr<2x2xi32>
return
}

// -----

func.func @bad_cshift8(%arg0: !hlfir.expr<?x!fir.char<1,?>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op kind mismatch between input and output arrays}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,?>>, i32) -> !hlfir.expr<?x!fir.char<2,?>>
return
}

// -----

func.func @bad_cshift9(%arg0: !hlfir.expr<?x!fir.char<1,1>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.cshift' op character LEN mismatch between input and output arrays}}
%0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x!fir.char<1,1>>, i32) -> !hlfir.expr<?x!fir.char<1,2>>
return
}
Loading