Skip to content

[flang] Handle volatility in lowering and codegen #135311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Builder/BoxValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class AbstractIrBox : public AbstractBox, public AbstractArrayBox {
auto ty = getBoxTy().getEleTy();
if (fir::isa_ref_type(ty))
return ty;
return fir::ReferenceType::get(ty);
return fir::ReferenceType::get(ty, fir::isa_volatile_type(ty));
}

/// Get the scalar type related to the described entity
Expand Down
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::Value createConvertWithVolatileCast(mlir::Location loc, mlir::Type toTy,
mlir::Value val);

/// Cast \p value to have \p isVolatile volatility.
mlir::Value createVolatileCast(mlir::Location loc, bool isVolatile,
mlir::Value value);

/// Create a fir.store of \p val into \p addr. A lazy conversion
/// of \p val to the element type of \p addr is created if needed.
void createStoreWithConvert(mlir::Location loc, mlir::Value val,
Expand Down
10 changes: 8 additions & 2 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ Fortran::lower::genCallOpAndResult(
charFuncPointerLength = charBox->getLen();
}
}

const bool isExprCall =
converter.getLoweringOptions().getLowerToHighLevelFIR() &&
callSiteType.getNumResults() == 1 &&
Expand Down Expand Up @@ -519,7 +518,9 @@ Fortran::lower::genCallOpAndResult(
// Do not attempt any reboxing here that could break this.
bool legacyLowering =
!converter.getLoweringOptions().getLowerToHighLevelFIR();
cast = builder.convertWithSemantics(loc, snd, fst,
bool isVolatile = fir::isa_volatile_type(snd);
cast = builder.createVolatileCast(loc, isVolatile, fst);
cast = builder.convertWithSemantics(loc, snd, cast,
callingImplicitInterface,
/*allowRebox=*/legacyLowering);
}
Expand Down Expand Up @@ -1415,6 +1416,11 @@ static PreparedDummyArgument preparePresentUserCallActualArgument(
addr = hlfir::genVariableRawAddress(loc, builder, entity);
}

// If the volatility of the input type does not match the dummy type,
// we need to cast the argument.
const bool isToTypeVolatile = fir::isa_volatile_type(dummyTypeWithActualRank);
addr = builder.createVolatileCast(loc, isToTypeVolatile, addr);

// For ranked actual passed to assumed-rank dummy, the cast to assumed-rank
// box is inserted when building the fir.call op. Inserting it here would
// cause the fir.if results to be assumed-rank in case of OPTIONAL dummy,
Expand Down
17 changes: 13 additions & 4 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,25 @@ class HlfirDesignatorBuilder {
auto charType = mlir::dyn_cast<fir::CharacterType>(resultValueType);
if (charType && charType.hasDynamicLen())
return fir::BoxCharType::get(charType.getContext(), charType.getFKind());

// When volatile is enabled, enable volatility on the designatory type.
const bool isVolatile = false;

// Arrays with non default lower bounds or dynamic length or dynamic extent
// need a fir.box to hold the dynamic or lower bound information.
if (fir::hasDynamicSize(resultValueType) ||
mayHaveNonDefaultLowerBounds(partInfo))
return fir::BoxType::get(resultValueType);
return fir::BoxType::get(resultValueType, isVolatile);

// Non simply contiguous ref require a fir.box to carry the byte stride.
if (mlir::isa<fir::SequenceType>(resultValueType) &&
!Fortran::evaluate::IsSimplyContiguous(
designatorNode, getConverter().getFoldingContext(),
/*namedConstantSectionsAreAlwaysContiguous=*/false))
return fir::BoxType::get(resultValueType);
return fir::BoxType::get(resultValueType, isVolatile);

// Other designators can be handled as raw addresses.
return fir::ReferenceType::get(resultValueType);
return fir::ReferenceType::get(resultValueType, isVolatile);
}

template <typename T>
Expand Down Expand Up @@ -1824,7 +1830,10 @@ class HlfirBuilder {
assert(compType && "failed to retrieve component type");
mlir::Value compShape =
designatorBuilder.genComponentShape(sym, compType);
mlir::Type designatorType = builder.getRefType(compType);
const bool isDesignatorVolatile =
fir::isa_volatile_type(baseOp.getType());
mlir::Type designatorType =
builder.getRefType(compType, isDesignatorVolatile);

mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
llvm::SmallVector<mlir::Value, 1> typeParams;
Expand Down
12 changes: 6 additions & 6 deletions flang/lib/Lower/HostAssociations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ class CapturedProcedure : public CapturedSymbols<CapturedProcedure> {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
assert(typeInTuple && "addrInTuple must be an address");
mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
fir::getBase(args.hostValue));
mlir::Value castBox = builder.createConvertWithVolatileCast(
args.loc, typeInTuple, fir::getBase(args.hostValue));
builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
}

Expand Down Expand Up @@ -265,8 +265,8 @@ class CapturedPolymorphicScalar
mlir::Location loc = args.loc;
mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
assert(typeInTuple && "addrInTuple must be an address");
mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
fir::getBase(args.hostValue));
mlir::Value castBox = builder.createConvertWithVolatileCast(
args.loc, typeInTuple, fir::getBase(args.hostValue));
if (Fortran::semantics::IsOptional(sym)) {
auto isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), castBox);
Expand Down Expand Up @@ -329,8 +329,8 @@ class CapturedAllocatableAndPointer
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Type typeInTuple = fir::dyn_cast_ptrEleTy(args.addrInTuple.getType());
assert(typeInTuple && "addrInTuple must be an address");
mlir::Value castBox = builder.createConvert(args.loc, typeInTuple,
fir::getBase(args.hostValue));
mlir::Value castBox = builder.createConvertWithVolatileCast(
args.loc, typeInTuple, fir::getBase(args.hostValue));
builder.create<fir::StoreOp>(args.loc, castBox, args.addrInTuple);
}
static void getFromTuple(const GetFromTuple &args,
Expand Down
13 changes: 8 additions & 5 deletions flang/lib/Lower/IO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ static void genOutputItemList(
fir::factory::CharacterExprHelper helper{builder, loc};
if (mlir::isa<fir::BoxType>(argType)) {
mlir::Value box = fir::getBase(converter.genExprBox(loc, *expr, stmtCtx));
outputFuncArgs.push_back(builder.createConvert(loc, argType, box));
outputFuncArgs.push_back(
builder.createConvertWithVolatileCast(loc, argType, box));
if (containsDerivedType(itemTy))
outputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter));
} else if (helper.isCharacterScalar(itemTy)) {
Expand All @@ -730,9 +731,9 @@ static void genOutputItemList(
if (!exv.getCharBox())
llvm::report_fatal_error(
"internal error: scalar character not in CharBox");
outputFuncArgs.push_back(builder.createConvert(
outputFuncArgs.push_back(builder.createConvertWithVolatileCast(
loc, outputFunc.getFunctionType().getInput(1), fir::getBase(exv)));
outputFuncArgs.push_back(builder.createConvert(
outputFuncArgs.push_back(builder.createConvertWithVolatileCast(
loc, outputFunc.getFunctionType().getInput(2), fir::getLen(exv)));
} else {
fir::ExtendedValue itemBox = converter.genExprValue(loc, expr, stmtCtx);
Expand All @@ -743,7 +744,8 @@ static void genOutputItemList(
outputFuncArgs.push_back(parts.first);
outputFuncArgs.push_back(parts.second);
} else {
itemValue = builder.createConvert(loc, argType, itemValue);
itemValue =
builder.createConvertWithVolatileCast(loc, argType, itemValue);
outputFuncArgs.push_back(itemValue);
}
}
Expand Down Expand Up @@ -827,7 +829,8 @@ createIoRuntimeCallForItem(Fortran::lower::AbstractConverter &converter,
mlir::Value box = fir::getBase(item);
auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(box.getType());
assert(boxTy && "must be previously emboxed");
inputFuncArgs.push_back(builder.createConvert(loc, argType, box));
auto casted = builder.createConvertWithVolatileCast(loc, argType, box);
inputFuncArgs.push_back(casted);
if (containsDerivedType(boxTy))
inputFuncArgs.push_back(getNonTbpDefinedIoTableAddr(converter));
} else {
Expand Down
16 changes: 11 additions & 5 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,20 @@ mlir::Value fir::FirOpBuilder::convertWithSemantics(
return createConvert(loc, toTy, val);
}

mlir::Value fir::FirOpBuilder::createVolatileCast(mlir::Location loc,
bool isVolatile,
mlir::Value val) {
mlir::Type volatileAdjustedType =
fir::updateTypeWithVolatility(val.getType(), isVolatile);
if (volatileAdjustedType == val.getType())
return val;
return create<fir::VolatileCastOp>(loc, volatileAdjustedType, val);
}

mlir::Value fir::FirOpBuilder::createConvertWithVolatileCast(mlir::Location loc,
mlir::Type toTy,
mlir::Value val) {
if (fir::isa_volatile_type(val.getType()) != fir::isa_volatile_type(toTy)) {
mlir::Type volatileAdjustedType = fir::updateTypeWithVolatility(
val.getType(), fir::isa_volatile_type(toTy));
val = create<fir::VolatileCastOp>(loc, volatileAdjustedType, val);
}
val = createVolatileCast(loc, fir::isa_volatile_type(toTy), val);
return createConvert(loc, toTy, val);
}

Expand Down
11 changes: 7 additions & 4 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,9 @@ hlfir::Entity hlfir::genVariableBox(mlir::Location loc,
mlir::Value addr = var.getBase();
if (mlir::isa<fir::BoxCharType>(var.getType()))
addr = genVariableRawAddress(loc, builder, var);
mlir::Type boxType = fir::BoxType::get(var.getElementOrSequenceType());
const bool isVolatile = fir::isa_volatile_type(var.getType());
mlir::Type boxType =
fir::BoxType::get(var.getElementOrSequenceType(), isVolatile);
if (forceBoxType) {
boxType = forceBoxType;
mlir::Type baseType =
Expand Down Expand Up @@ -793,15 +795,16 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
if (variable.isScalar())
return variable.getType();
mlir::Type eleTy = variable.getFortranElementType();
const bool isVolatile = fir::isa_volatile_type(variable.getType());
if (variable.isPolymorphic())
return fir::ClassType::get(eleTy);
return fir::ClassType::get(eleTy, isVolatile);
if (auto charType = mlir::dyn_cast<fir::CharacterType>(eleTy)) {
if (charType.hasDynamicLen())
return fir::BoxCharType::get(charType.getContext(), charType.getFKind());
} else if (fir::isRecordWithTypeParameters(eleTy)) {
return fir::BoxType::get(eleTy);
return fir::BoxType::get(eleTy, isVolatile);
}
return fir::ReferenceType::get(eleTy);
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
Expand Down
12 changes: 9 additions & 3 deletions flang/lib/Optimizer/Builder/MutableBox.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ createNewFirBox(fir::FirOpBuilder &builder, mlir::Location loc,
cleanedLengths = lengths;
}
mlir::Value emptySlice;
return builder.create<fir::EmboxOp>(loc, box.getBoxTy(), cleanedAddr, shape,
auto boxType = fir::updateTypeWithVolatility(
box.getBoxTy(), fir::isa_volatile_type(cleanedAddr.getType()));
return builder.create<fir::EmboxOp>(loc, boxType, cleanedAddr, shape,
emptySlice, cleanedLengths, tdesc);
}

Expand Down Expand Up @@ -281,6 +283,9 @@ class MutablePropertyWriter {
unsigned allocator = kDefaultAllocator) {
mlir::Value irBox = createNewFirBox(builder, loc, box, addr, lbounds,
extents, lengths, tdesc);
const bool valueTypeIsVolatile =
fir::isa_volatile_type(fir::unwrapRefType(box.getAddr().getType()));
irBox = builder.createVolatileCast(loc, valueTypeIsVolatile, irBox);
builder.create<fir::StoreOp>(loc, irBox, box.getAddr());
}

Expand Down Expand Up @@ -346,7 +351,8 @@ mlir::Value fir::factory::createUnallocatedBox(
baseBoxType = baseBoxType.getBoxTypeWithNewShape(/*rank=*/0);
auto baseAddrType = baseBoxType.getEleTy();
if (!fir::isa_ref_type(baseAddrType))
baseAddrType = builder.getRefType(baseAddrType);
baseAddrType =
builder.getRefType(baseAddrType, fir::isa_volatile_type(baseBoxType));
auto type = fir::unwrapRefType(baseAddrType);
auto eleTy = fir::unwrapSequenceType(type);
if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy))
Expand Down Expand Up @@ -516,7 +522,7 @@ void fir::factory::associateMutableBox(fir::FirOpBuilder &builder,
source.match(
[&](const fir::PolymorphicValue &p) {
mlir::Value sourceBox;
if (auto polyBox = source.getBoxOf<fir::PolymorphicValue>())
if (auto *polyBox = source.getBoxOf<fir::PolymorphicValue>())
sourceBox = polyBox->getSourceBox();
writer.updateMutableBox(p.getAddr(), /*lbounds=*/std::nullopt,
/*extents=*/std::nullopt,
Expand Down
62 changes: 42 additions & 20 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,20 @@ struct CmpcOpConversion : public fir::FIROpConversion<fir::CmpcOp> {
}
};

/// fir.volatile_cast is only useful at the fir level. Once we lower to LLVM,
/// volatility is described by setting volatile attributes on the LLVM ops.
struct VolatileCastOpConversion
: public fir::FIROpConversion<fir::VolatileCastOp> {
using FIROpConversion::FIROpConversion;

llvm::LogicalResult
matchAndRewrite(fir::VolatileCastOp volatileCast, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(volatileCast, adaptor.getOperands()[0]);
return mlir::success();
}
};

/// convert value of from-type to value of to-type
struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
using FIROpConversion::FIROpConversion;
Expand Down Expand Up @@ -3224,6 +3238,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type llvmLoadTy = convertObjectType(load.getType());
const bool isVolatile = fir::isa_volatile_type(load.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
// fir.box is a special case because it is considered an ssa value in
// fir, but it is lowered as a pointer to a descriptor. So
Expand Down Expand Up @@ -3253,16 +3268,17 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);

if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
memcpy.setTBAATags(*optionalTag);
else
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
rewriter.replaceOp(load, newBoxStorage);
} else {
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
mlir::LLVM::LoadOp loadOp = rewriter.create<mlir::LLVM::LoadOp>(
load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
loadOp.setVolatile_(isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
loadOp.setTBAATags(*optionalTag);
else
Expand Down Expand Up @@ -3540,17 +3556,22 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::Value llvmValue = adaptor.getValue();
mlir::Value llvmMemref = adaptor.getMemref();
mlir::LLVM::AliasAnalysisOpInterface newOp;
const bool isVolatile = fir::isa_volatile_type(store.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
// Always use memcpy because LLVM is not as effective at optimizing
// aggregate loads/stores as it is optimizing memcpy.
TypePair boxTypePair{boxTy, llvmBoxTy};
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
boxSize, isVolatile);
} else {
newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
mlir::LLVM::StoreOp storeOp =
rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
if (isVolatile)
storeOp.setVolatile_(true);
newOp = storeOp;
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
newOp.setTBAATags(*optionalTag);
Expand Down Expand Up @@ -4193,21 +4214,22 @@ void fir::populateFIRToLLVMConversionPatterns(
BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion,
BoxOffsetOpConversion, BoxProcHostOpConversion, BoxRankOpConversion,
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
CmpcOpConversion, ConvertOpConversion, CoordinateOpConversion,
CopyOpConversion, DTEntryOpConversion, DeclareOpConversion,
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
XReboxOpConversion, ZeroOpConversion>(converter, options);
CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion,
CoordinateOpConversion, CopyOpConversion, DTEntryOpConversion,
DeclareOpConversion, DivcOpConversion, EmboxOpConversion,
EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
MulcOpConversion, NegcOpConversion, NoReassocOpConversion,
SelectCaseOpConversion, SelectOpConversion, SelectRankOpConversion,
SelectTypeOpConversion, ShapeOpConversion, ShapeShiftOpConversion,
ShiftOpConversion, SliceOpConversion, StoreOpConversion,
StringLitOpConversion, SubcOpConversion, TypeDescOpConversion,
TypeInfoOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
options);

// Patterns that are populated without a type converter do not trigger
// target materializations for the operands of the root op.
Expand Down
Loading