Skip to content

[flang] Correctly prepare allocatable runtime call arguments #138727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "flang/Support/Fortran.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include <functional>
Expand Down Expand Up @@ -824,33 +825,23 @@ static mlir::func::FuncOp getIORuntimeFunc(mlir::Location loc,
return getRuntimeFunc<E>(loc, builder, /*isIO=*/true);
}

namespace helper {
template <int N, typename A>
void createArguments(llvm::SmallVectorImpl<mlir::Value> &result,
fir::FirOpBuilder &builder, mlir::Location loc,
mlir::FunctionType fTy, A arg) {
result.emplace_back(
builder.createConvertWithVolatileCast(loc, fTy.getInput(N), arg));
}

template <int N, typename A, typename... As>
void createArguments(llvm::SmallVectorImpl<mlir::Value> &result,
fir::FirOpBuilder &builder, mlir::Location loc,
mlir::FunctionType fTy, A arg, As... args) {
result.emplace_back(
builder.createConvertWithVolatileCast(loc, fTy.getInput(N), arg));
createArguments<N + 1>(result, builder, loc, fTy, args...);
inline llvm::SmallVector<mlir::Value>
createArguments(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::FunctionType fTy, llvm::ArrayRef<mlir::Value> args) {
return llvm::map_to_vector(llvm::zip_equal(fTy.getInputs(), args),
[&](const auto &pair) -> mlir::Value {
auto [type, argument] = pair;
return builder.createConvertWithVolatileCast(
loc, type, argument);
});
}
} // namespace helper

/// Create a SmallVector of arguments for a runtime call.
template <typename... As>
llvm::SmallVector<mlir::Value>
createArguments(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::FunctionType fTy, As... args) {
llvm::SmallVector<mlir::Value> result;
helper::createArguments<0>(result, builder, loc, fTy, args...);
return result;
return createArguments(builder, loc, fTy, {args...});
}

} // namespace fir::runtime
Expand Down
65 changes: 26 additions & 39 deletions flang/lib/Lower/Allocatable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,10 @@ static void genRuntimeSetBounds(fir::FirOpBuilder &builder, mlir::Location loc,
builder)
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableSetBounds)>(
loc, builder);
llvm::SmallVector<mlir::Value> args{box.getAddr(), dimIndex, lowerBound,
upperBound};
llvm::SmallVector<mlir::Value> operands;
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
operands.emplace_back(builder.createConvert(loc, snd, fst));
builder.create<fir::CallOp>(loc, callee, operands);
const auto args = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), box.getAddr(), dimIndex,
lowerBound, upperBound);
builder.create<fir::CallOp>(loc, callee, args);
}

/// Generate runtime call to set the lengths of a character allocatable or
Expand All @@ -162,9 +160,7 @@ static void genRuntimeInitCharacter(fir::FirOpBuilder &builder,
if (inputTypes.size() != 5)
fir::emitFatalError(
loc, "AllocatableInitCharacter runtime interface not as expected");
llvm::SmallVector<mlir::Value> args;
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
args.push_back(builder.createConvert(loc, inputTypes[1], len));
llvm::SmallVector<mlir::Value> args = {box.getAddr(), len};
if (kind == 0)
kind = mlir::cast<fir::CharacterType>(box.getEleTy()).getFKind();
args.push_back(builder.createIntegerConstant(loc, inputTypes[2], kind));
Expand All @@ -173,7 +169,9 @@ static void genRuntimeInitCharacter(fir::FirOpBuilder &builder,
// TODO: coarrays
int corank = 0;
args.push_back(builder.createIntegerConstant(loc, inputTypes[4], corank));
builder.create<fir::CallOp>(loc, callee, args);
const auto convertedArgs = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), args);
builder.create<fir::CallOp>(loc, callee, convertedArgs);
}

/// Generate a sequence of runtime calls to allocate memory.
Expand All @@ -194,10 +192,9 @@ static mlir::Value genRuntimeAllocate(fir::FirOpBuilder &builder,
args.push_back(errorManager.errMsgAddr);
args.push_back(errorManager.sourceFile);
args.push_back(errorManager.sourceLine);
llvm::SmallVector<mlir::Value> operands;
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
operands.emplace_back(builder.createConvert(loc, snd, fst));
return builder.create<fir::CallOp>(loc, callee, operands).getResult(0);
const auto convertedArgs = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), args);
return builder.create<fir::CallOp>(loc, callee, convertedArgs).getResult(0);
}

/// Generate a sequence of runtime calls to allocate memory and assign with the
Expand All @@ -213,14 +210,11 @@ static mlir::Value genRuntimeAllocateSource(fir::FirOpBuilder &builder,
loc, builder)
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocateSource)>(
loc, builder);
llvm::SmallVector<mlir::Value> args{
box.getAddr(), fir::getBase(source),
errorManager.hasStat, errorManager.errMsgAddr,
errorManager.sourceFile, errorManager.sourceLine};
llvm::SmallVector<mlir::Value> operands;
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
operands.emplace_back(builder.createConvert(loc, snd, fst));
return builder.create<fir::CallOp>(loc, callee, operands).getResult(0);
const auto args = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), box.getAddr(),
fir::getBase(source), errorManager.hasStat, errorManager.errMsgAddr,
errorManager.sourceFile, errorManager.sourceLine);
return builder.create<fir::CallOp>(loc, callee, args).getResult(0);
}

/// Generate runtime call to apply mold to the descriptor.
Expand All @@ -234,14 +228,12 @@ static void genRuntimeAllocateApplyMold(fir::FirOpBuilder &builder,
builder)
: fir::runtime::getRuntimeFunc<mkRTKey(AllocatableApplyMold)>(
loc, builder);
llvm::SmallVector<mlir::Value> args{
const auto args = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(),
fir::factory::getMutableIRBox(builder, loc, box), fir::getBase(mold),
builder.createIntegerConstant(
loc, callee.getFunctionType().getInputs()[2], rank)};
llvm::SmallVector<mlir::Value> operands;
for (auto [fst, snd] : llvm::zip(args, callee.getFunctionType().getInputs()))
operands.emplace_back(builder.createConvert(loc, snd, fst));
builder.create<fir::CallOp>(loc, callee, operands);
loc, callee.getFunctionType().getInputs()[2], rank));
builder.create<fir::CallOp>(loc, callee, args);
}

/// Generate a runtime call to deallocate memory.
Expand Down Expand Up @@ -669,15 +661,13 @@ class AllocateStmtHelper {

llvm::ArrayRef<mlir::Type> inputTypes =
callee.getFunctionType().getInputs();
llvm::SmallVector<mlir::Value> args;
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
args.push_back(builder.createConvert(loc, inputTypes[1], typeDescAddr));
mlir::Value rankValue =
builder.createIntegerConstant(loc, inputTypes[2], rank);
mlir::Value corankValue =
builder.createIntegerConstant(loc, inputTypes[3], corank);
args.push_back(rankValue);
args.push_back(corankValue);
const auto args = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), box.getAddr(), typeDescAddr,
rankValue, corankValue);
builder.create<fir::CallOp>(loc, callee, args);
}

Expand All @@ -696,8 +686,6 @@ class AllocateStmtHelper {

llvm::ArrayRef<mlir::Type> inputTypes =
callee.getFunctionType().getInputs();
llvm::SmallVector<mlir::Value> args;
args.push_back(builder.createConvert(loc, inputTypes[0], box.getAddr()));
mlir::Value categoryValue = builder.createIntegerConstant(
loc, inputTypes[1], static_cast<int32_t>(category));
mlir::Value kindValue =
Expand All @@ -706,10 +694,9 @@ class AllocateStmtHelper {
builder.createIntegerConstant(loc, inputTypes[3], rank);
mlir::Value corankValue =
builder.createIntegerConstant(loc, inputTypes[4], corank);
args.push_back(categoryValue);
args.push_back(kindValue);
args.push_back(rankValue);
args.push_back(corankValue);
const auto args = fir::runtime::createArguments(
builder, loc, callee.getFunctionType(), box.getAddr(), categoryValue,
kindValue, rankValue, corankValue);
builder.create<fir::CallOp>(loc, callee, args);
}

Expand Down
12 changes: 6 additions & 6 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ class HlfirDesignatorBuilder {
isVolatile = true;
}

// Check if the base type is volatile
if (partInfo.base.has_value()) {
mlir::Type baseType = partInfo.base.value().getType();
isVolatile = isVolatile || fir::isa_volatile_type(baseType);
}

// Dynamic type of polymorphic base must be kept if the designator is
// polymorphic.
if (isPolymorphic(designatorNode))
Expand All @@ -238,12 +244,6 @@ class HlfirDesignatorBuilder {
if (charType && charType.hasDynamicLen())
return fir::BoxCharType::get(charType.getContext(), charType.getFKind());

// Check if the base type is volatile
if (partInfo.base.has_value()) {
mlir::Type baseType = partInfo.base.value().getType();
isVolatile = isVolatile || fir::isa_volatile_type(baseType);
}

// Arrays with non default lower bounds or dynamic length or dynamic extent
// need a fir.box to hold the dynamic or lower bound information.
if (fir::hasDynamicSize(resultValueType) ||
Expand Down
8 changes: 3 additions & 5 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,14 @@ static bool hasExplicitLowerBounds(mlir::Value shape) {
static std::pair<mlir::Type, mlir::Value> updateDeclareInputTypeWithVolatility(
mlir::Type inputType, mlir::Value memref, mlir::OpBuilder &builder,
fir::FortranVariableFlagsAttr fortran_attrs) {
if (mlir::isa<fir::BoxType, fir::ReferenceType>(inputType) && fortran_attrs &&
if (fortran_attrs &&
bitEnumContainsAny(fortran_attrs.getFlags(),
fir::FortranVariableFlagsEnum::fortran_volatile)) {
const bool isPointer = bitEnumContainsAny(
fortran_attrs.getFlags(), fir::FortranVariableFlagsEnum::pointer);
auto updateType = [&](auto t) {
using FIRT = decltype(t);
// If an entity is a pointer, the entity it points to is volatile, as far
// as consumers of the pointer are concerned.
// A volatile pointer's pointee is volatile.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to your patch, but since you updated the comment :), why doesn't this also apply to the target data of an allocatable descriptor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out! I discussed this with Slava, it's next on my volatile todo list :)

auto elementType = t.getEleTy();
const bool elementTypeIsVolatile =
isPointer || fir::isa_volatile_type(elementType);
Expand All @@ -227,8 +226,7 @@ static std::pair<mlir::Type, mlir::Value> updateDeclareInputTypeWithVolatility(
inputType = FIRT::get(newEleTy, true);
};
llvm::TypeSwitch<mlir::Type>(inputType)
.Case<fir::ReferenceType, fir::BoxType>(updateType)
.Default([](mlir::Type t) { return t; });
.Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(updateType);
memref =
builder.create<fir::VolatileCastOp>(memref.getLoc(), inputType, memref);
}
Expand Down
Loading