Skip to content

Commit bbb7f01

Browse files
[flang] Fix volatile attribute propagation on allocatables (#139183)
Ensure volatility is reflected not just on the reference to an allocatable, but on the box, too. When we declare a volatile allocatable, we now get a volatile reference to a volatile box. Some related cleanups: * SELECT TYPE constructs check the selector's type for volatility when creating and designating the type used in the selecting block. * Refine the verifier for fir.convert. In general, I think it is ok to implicitly drop volatility in any ptr-to-int conversion because it means we are in codegen (and representing volatility on the LLVM ops and intrinsics) or we are calling an external function (are there any cases I'm not thinking of?) * An allocatable test that was XFAILed is now passing. Making allocatables' boxes volatile resulted in accesses of those boxes being volatile, which resolved some errors coming from the strict verifier. * I noticed a runtime function was missing the fir.runtime attribute.
1 parent 4b794c8 commit bbb7f01

File tree

7 files changed

+121
-52
lines changed

7 files changed

+121
-52
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3842,6 +3842,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
38423842
bool hasLocalScope = false;
38433843
llvm::SmallVector<const Fortran::semantics::Scope *> typeCaseScopes;
38443844

3845+
const auto selectorIsVolatile = [&selector]() {
3846+
return fir::isa_volatile_type(fir::getBase(selector).getType());
3847+
};
3848+
38453849
const auto &typeCaseList =
38463850
std::get<std::list<Fortran::parser::SelectTypeConstruct::TypeCase>>(
38473851
selectTypeConstruct.t);
@@ -3995,7 +3999,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
39953999
addrTy = fir::HeapType::get(addrTy);
39964000
if (std::holds_alternative<Fortran::parser::IntrinsicTypeSpec>(
39974001
typeSpec->u)) {
3998-
mlir::Type refTy = fir::ReferenceType::get(addrTy);
4002+
mlir::Type refTy =
4003+
fir::ReferenceType::get(addrTy, selectorIsVolatile());
39994004
if (isPointer || isAllocatable)
40004005
refTy = addrTy;
40014006
exactValue = builder->create<fir::BoxAddrOp>(
@@ -4004,7 +4009,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
40044009
typeSpec->declTypeSpec->AsIntrinsic();
40054010
if (isArray) {
40064011
mlir::Value exact = builder->create<fir::ConvertOp>(
4007-
loc, fir::BoxType::get(addrTy), fir::getBase(selector));
4012+
loc, fir::BoxType::get(addrTy, selectorIsVolatile()),
4013+
fir::getBase(selector));
40084014
addAssocEntitySymbol(selectorBox->clone(exact));
40094015
} else if (intrinsic->category() ==
40104016
Fortran::common::TypeCategory::Character) {
@@ -4019,7 +4025,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
40194025
} else if (std::holds_alternative<Fortran::parser::DerivedTypeSpec>(
40204026
typeSpec->u)) {
40214027
exactValue = builder->create<fir::ConvertOp>(
4022-
loc, fir::BoxType::get(addrTy), fir::getBase(selector));
4028+
loc, fir::BoxType::get(addrTy, selectorIsVolatile()),
4029+
fir::getBase(selector));
40234030
addAssocEntitySymbol(selectorBox->clone(exactValue));
40244031
}
40254032
} else if (std::holds_alternative<Fortran::parser::DerivedTypeSpec>(
@@ -4037,7 +4044,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
40374044
addrTy = fir::PointerType::get(addrTy);
40384045
if (isAllocatable)
40394046
addrTy = fir::HeapType::get(addrTy);
4040-
mlir::Type classTy = fir::ClassType::get(addrTy);
4047+
mlir::Type classTy =
4048+
fir::ClassType::get(addrTy, selectorIsVolatile());
40414049
if (classTy == baseTy) {
40424050
addAssocEntitySymbol(selector);
40434051
} else {

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,20 +1536,50 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
15361536
areRecordsCompatible(inType, outType);
15371537
}
15381538

1539+
// In general, ptrtoint-like conversions are allowed to lose volatility
1540+
// information because they are either:
1541+
//
1542+
// 1. passing an entity to an external function and there's nothing we can do
1543+
// about volatility after that happens, or
1544+
// 2. for code generation, at which point we represent volatility with
1545+
// attributes on the LLVM instructions and intrinsics.
1546+
//
1547+
// For all other cases, volatility ought to match exactly.
1548+
static mlir::LogicalResult verifyVolatility(mlir::Type inType,
1549+
mlir::Type outType) {
1550+
const bool toLLVMPointer = mlir::isa<mlir::LLVM::LLVMPointerType>(outType);
1551+
const bool toInteger = fir::isa_integer(outType);
1552+
1553+
// When converting references to classes or allocatables into boxes for
1554+
// runtime arguments, we cast away all the volatility information and pass a
1555+
// box<none>. This is allowed.
1556+
const bool isBoxNoneLike = [&]() {
1557+
if (fir::isBoxNone(outType))
1558+
return true;
1559+
if (auto referenceType = mlir::dyn_cast<fir::ReferenceType>(outType)) {
1560+
if (fir::isBoxNone(referenceType.getElementType())) {
1561+
return true;
1562+
}
1563+
}
1564+
return false;
1565+
}();
1566+
1567+
const bool isPtrToIntLike = toLLVMPointer || toInteger || isBoxNoneLike;
1568+
if (isPtrToIntLike) {
1569+
return mlir::success();
1570+
}
1571+
1572+
// In all other cases, we need to check for an exact volatility match.
1573+
return mlir::success(fir::isa_volatile_type(inType) ==
1574+
fir::isa_volatile_type(outType));
1575+
}
1576+
15391577
llvm::LogicalResult fir::ConvertOp::verify() {
15401578
mlir::Type inType = getValue().getType();
15411579
mlir::Type outType = getType();
1542-
// If we're converting to an LLVM pointer type or an integer, we don't
1543-
// need to check for volatility mismatch - volatility will be handled by the
1544-
// memory operations themselves in llvm code generation and ptr-to-int can't
1545-
// represent volatility.
1546-
const bool toLLVMPointer = mlir::isa<mlir::LLVM::LLVMPointerType>(outType);
1547-
const bool toInteger = fir::isa_integer(outType);
15481580
if (fir::useStrictVolatileVerification()) {
1549-
if (fir::isa_volatile_type(inType) != fir::isa_volatile_type(outType) &&
1550-
!toLLVMPointer && !toInteger) {
1551-
return emitOpError("cannot convert between volatile and non-volatile "
1552-
"types, use fir.volatile_cast instead ")
1581+
if (failed(verifyVolatility(inType, outType))) {
1582+
return emitOpError("this conversion does not preserve volatility: ")
15531583
<< inType << " / " << outType;
15541584
}
15551585
}

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,29 +207,37 @@ static bool hasExplicitLowerBounds(mlir::Value shape) {
207207
mlir::isa<fir::ShapeShiftType, fir::ShiftType>(shape.getType());
208208
}
209209

210-
static std::pair<mlir::Type, mlir::Value> updateDeclareInputTypeWithVolatility(
211-
mlir::Type inputType, mlir::Value memref, mlir::OpBuilder &builder,
212-
fir::FortranVariableFlagsAttr fortran_attrs) {
213-
if (fortran_attrs &&
214-
bitEnumContainsAny(fortran_attrs.getFlags(),
215-
fir::FortranVariableFlagsEnum::fortran_volatile)) {
216-
const bool isPointer = bitEnumContainsAny(
217-
fortran_attrs.getFlags(), fir::FortranVariableFlagsEnum::pointer);
218-
auto updateType = [&](auto t) {
219-
using FIRT = decltype(t);
220-
// A volatile pointer's pointee is volatile.
221-
auto elementType = t.getEleTy();
222-
const bool elementTypeIsVolatile =
223-
isPointer || fir::isa_volatile_type(elementType);
224-
auto newEleTy =
225-
fir::updateTypeWithVolatility(elementType, elementTypeIsVolatile);
226-
inputType = FIRT::get(newEleTy, true);
227-
};
228-
llvm::TypeSwitch<mlir::Type>(inputType)
229-
.Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(updateType);
230-
memref =
231-
builder.create<fir::VolatileCastOp>(memref.getLoc(), inputType, memref);
210+
static std::pair<mlir::Type, mlir::Value>
211+
updateDeclaredInputTypeWithVolatility(mlir::Type inputType, mlir::Value memref,
212+
mlir::OpBuilder &builder,
213+
fir::FortranVariableFlagsEnum flags) {
214+
if (!bitEnumContainsAny(flags,
215+
fir::FortranVariableFlagsEnum::fortran_volatile)) {
216+
return std::make_pair(inputType, memref);
232217
}
218+
219+
// A volatile pointer's pointee is volatile.
220+
const bool isPointer =
221+
bitEnumContainsAny(flags, fir::FortranVariableFlagsEnum::pointer);
222+
// An allocatable's inner type's volatility matches that of the reference.
223+
const bool isAllocatable =
224+
bitEnumContainsAny(flags, fir::FortranVariableFlagsEnum::allocatable);
225+
226+
auto updateType = [&](auto t) {
227+
using FIRT = decltype(t);
228+
auto elementType = t.getEleTy();
229+
const bool elementTypeIsBox = mlir::isa<fir::BaseBoxType>(elementType);
230+
const bool elementTypeIsVolatile = isPointer || isAllocatable ||
231+
elementTypeIsBox ||
232+
fir::isa_volatile_type(elementType);
233+
auto newEleTy =
234+
fir::updateTypeWithVolatility(elementType, elementTypeIsVolatile);
235+
inputType = FIRT::get(newEleTy, true);
236+
};
237+
llvm::TypeSwitch<mlir::Type>(inputType)
238+
.Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(updateType);
239+
memref =
240+
builder.create<fir::VolatileCastOp>(memref.getLoc(), inputType, memref);
233241
return std::make_pair(inputType, memref);
234242
}
235243

@@ -243,8 +251,11 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
243251
auto nameAttr = builder.getStringAttr(uniq_name);
244252
mlir::Type inputType = memref.getType();
245253
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
246-
std::tie(inputType, memref) = updateDeclareInputTypeWithVolatility(
247-
inputType, memref, builder, fortran_attrs);
254+
if (fortran_attrs) {
255+
const auto flags = fortran_attrs.getFlags();
256+
std::tie(inputType, memref) = updateDeclaredInputTypeWithVolatility(
257+
inputType, memref, builder, flags);
258+
}
248259
mlir::Type hlfirVariableType =
249260
getHLFIRVariableType(inputType, hasExplicitLbs);
250261
build(builder, result, {hlfirVariableType, inputType}, memref, shape,

flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,10 +401,14 @@ llvm::LogicalResult SelectTypeConv::genTypeLadderStep(
401401
{
402402
// Since conversion is done in parallel for each fir.select_type
403403
// operation, the runtime function insertion must be threadsafe.
404+
auto runtimeAttr =
405+
mlir::NamedAttribute(fir::FIROpsDialect::getFirRuntimeAttrName(),
406+
mlir::UnitAttr::get(rewriter.getContext()));
404407
callee =
405408
fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName,
406409
rewriter.getFunctionType({descNoneTy, typeDescTy},
407-
rewriter.getI1Type()));
410+
rewriter.getI1Type()),
411+
{runtimeAttr});
408412
}
409413
cmp = rewriter
410414
.create<fir::CallOp>(loc, callee,

flang/test/Fir/invalid.fir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
12601260

12611261
// Should fail when volatility changes from a fir.convert
12621262
func.func @bad_convert_volatile(%arg0: !fir.ref<i32>) -> !fir.ref<i32, volatile> {
1263-
// expected-error@+1 {{'fir.convert' op cannot convert between volatile and non-volatile types, use fir.volatile_cast instead}}
1263+
// expected-error@+1 {{op this conversion does not preserve volatility}}
12641264
%0 = fir.convert %arg0 : (!fir.ref<i32>) -> !fir.ref<i32, volatile>
12651265
return %0 : !fir.ref<i32, volatile>
12661266
}
@@ -1269,7 +1269,7 @@ func.func @bad_convert_volatile(%arg0: !fir.ref<i32>) -> !fir.ref<i32, volatile>
12691269

12701270
// Should fail when volatility changes from a fir.convert
12711271
func.func @bad_convert_volatile2(%arg0: !fir.ref<i32, volatile>) -> !fir.ref<i32> {
1272-
// expected-error@+1 {{'fir.convert' op cannot convert between volatile and non-volatile types, use fir.volatile_cast instead}}
1272+
// expected-error@+1 {{op this conversion does not preserve volatility}}
12731273
%0 = fir.convert %arg0 : (!fir.ref<i32, volatile>) -> !fir.ref<i32>
12741274
return %0 : !fir.ref<i32>
12751275
}

0 commit comments

Comments
 (0)