Skip to content

[flang][nfc] Support volatile on ref, box, and class types #134386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::Block *getAllocaBlock();

/// Safely create a reference type to the type `eleTy`.
mlir::Type getRefType(mlir::Type eleTy);
mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);

/// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
Expand Down
8 changes: 8 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ inline bool isa_char_string(mlir::Type t) {
/// (since they may hold one), and are not considered to be unknown size.
bool isa_unknown_size_box(mlir::Type t);

/// Returns true iff `t` is a type capable of representing volatility and has
/// the volatile attribute set.
bool isa_volatile_type(mlir::Type t);

/// Returns true iff `t` is a fir.char type and has an unknown length.
inline bool characterWithDynamicLen(mlir::Type t) {
if (auto charTy = mlir::dyn_cast<fir::CharacterType>(t))
Expand Down Expand Up @@ -474,6 +478,10 @@ inline mlir::Type updateTypeForUnlimitedPolymorphic(mlir::Type ty) {
return ty;
}

/// Re-create the given type with the given volatility, if this is a type
/// that can represent volatility.
mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile);

/// Replace the element type of \p type by \p newElementType, preserving
/// all other layers of the type (fir.ref/ptr/heap/array/box/class).
/// If \p turnBoxIntoClass and the input is a fir.box, it will be turned into
Expand Down
29 changes: 17 additions & 12 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ def fir_BoxType : FIR_Type<"Box", "box", [], "BaseBoxType"> {
to) whether the entity is an array, its size, or what type it has.
}];

let parameters = (ins "mlir::Type":$eleTy);
let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile);

let skipDefaultBuilders = 1;

let builders = [
TypeBuilderWithInferredContext<(ins
"mlir::Type":$eleTy), [{
return Base::get(eleTy.getContext(), eleTy);
"mlir::Type":$eleTy, CArg<"bool", "false">:$isVolatile), [{
return Base::get(eleTy.getContext(), eleTy, isVolatile);
}]>,
];

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
bool isVolatile() const { return getIsVolatile(); }
}];

let genVerifyDecl = 1;

let assemblyFormat = "`<` $eleTy `>`";
let hasCustomAssemblyFormat = 1;
}

def fir_CharacterType : FIR_Type<"Character", "char"> {
Expand Down Expand Up @@ -146,16 +146,20 @@ def fir_ClassType : FIR_Type<"Class", "class", [], "BaseBoxType"> {
is equivalent to a fir.box type with a dynamic type.
}];

let parameters = (ins "mlir::Type":$eleTy);
let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile);

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$eleTy), [{
return $_get(eleTy.getContext(), eleTy);
TypeBuilderWithInferredContext<(ins "mlir::Type":$eleTy, CArg<"bool", "false">:$isVolatile), [{
return $_get(eleTy.getContext(), eleTy, isVolatile);
}]>
];

let extraClassDeclaration = [{
bool isVolatile() const { return getIsVolatile(); }
}];

let genVerifyDecl = 1;
let assemblyFormat = "`<` $eleTy `>`";
let hasCustomAssemblyFormat = 1;
}

def fir_FieldType : FIR_Type<"Field", "field"> {
Expand Down Expand Up @@ -363,18 +367,19 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
The type of a reference to an entity in memory.
}];

let parameters = (ins "mlir::Type":$eleTy);
let parameters = (ins "mlir::Type":$eleTy, "bool":$isVolatile);

let skipDefaultBuilders = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVolatile), [{
return Base::get(elementType.getContext(), elementType, isVolatile);
}]>,
];

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
bool isVolatile() const { return getIsVolatile(); }
}];

let genVerifyDecl = 1;
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
return modOp.lookupSymbol<fir::GlobalOp>(name);
}

mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
return fir::ReferenceType::get(eleTy);
return fir::ReferenceType::get(eleTy, isVolatile);
}

mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
Expand Down
108 changes: 101 additions & 7 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ using namespace fir;

namespace {

static llvm::StringRef getVolatileKeyword() { return "volatile"; }

static mlir::ParseResult parseOptionalCommaAndKeyword(mlir::AsmParser &parser,
mlir::StringRef keyword,
bool &parsedKeyword) {
if (!parser.parseOptionalComma()) {
if (parser.parseKeyword(keyword))
return mlir::failure();
parsedKeyword = true;
return mlir::success();
}
parsedKeyword = false;
return mlir::success();
}

template <typename TYPE>
TYPE parseIntSingleton(mlir::AsmParser &parser) {
int kind = 0;
Expand Down Expand Up @@ -215,6 +230,19 @@ mlir::Type getDerivedType(mlir::Type ty) {
.Default([](mlir::Type t) { return t; });
}

mlir::Type updateTypeWithVolatility(mlir::Type type, bool isVolatile) {
// If we already have the volatility we asked for, return the type unchanged.
if (fir::isa_volatile_type(type) == isVolatile)
return type;
return mlir::TypeSwitch<mlir::Type, mlir::Type>(type)
.Case<fir::BoxType, fir::ClassType, fir::ReferenceType>(
[&](auto ty) -> mlir::Type {
using TYPE = decltype(ty);
return TYPE::get(ty.getEleTy(), isVolatile);
})
.Default([&](mlir::Type t) -> mlir::Type { return t; });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: perhaps the default case should trigger an assertion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I thought about that, but when we convert arguments for function calls (#132486, FIRBuilder::createConvertWithVolatileCast) we want to convert values of all types to match the volatility of declared argument types, and values that can't represent volatility should pass through unchanged. If we'd rather assert in the default case, we could add a utility method like type_can_be_volatile, and if it's true for the type of a parameter, then convert the type using the method above to match volatility so we don't hit the assert. That seemed a bit more cumbersome to me - but let me know if you think differently and I'll revisit!

}

mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
.Case<fir::ReferenceType, fir::PointerType, fir::HeapType,
Expand Down Expand Up @@ -701,6 +729,13 @@ bool fir::isa_unknown_size_box(mlir::Type t) {
return false;
}

bool fir::isa_volatile_type(mlir::Type t) {
return llvm::TypeSwitch<mlir::Type, bool>(t)
.Case<fir::ReferenceType, fir::BoxType, fir::ClassType>(
[](auto t) { return t.isVolatile(); })
.Default([](mlir::Type) { return false; });
}

//===----------------------------------------------------------------------===//
// BoxProcType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -738,9 +773,31 @@ static bool cannotBePointerOrHeapElementType(mlir::Type eleTy) {
// BoxType
//===----------------------------------------------------------------------===//

// `box` `<` type (`, volatile` $volatile^)? `>`
mlir::Type fir::BoxType::parse(mlir::AsmParser &parser) {
mlir::Type eleTy;
auto location = parser.getCurrentLocation();
auto *context = parser.getContext();
bool isVolatile = false;
if (parser.parseLess() || parser.parseType(eleTy))
return {};
if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
return {};
if (parser.parseGreater())
return {};
return parser.getChecked<fir::BoxType>(location, context, eleTy, isVolatile);
}

void fir::BoxType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getEleTy();
if (isVolatile())
printer << ", " << getVolatileKeyword();
printer << '>';
}

llvm::LogicalResult
fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
mlir::Type eleTy, bool isVolatile) {
if (mlir::isa<fir::BaseBoxType>(eleTy))
return emitError() << "invalid element type\n";
// TODO
Expand Down Expand Up @@ -807,9 +864,32 @@ void fir::CharacterType::print(mlir::AsmPrinter &printer) const {
// ClassType
//===----------------------------------------------------------------------===//

// `class` `<` type (`, volatile` $volatile^)? `>`
mlir::Type fir::ClassType::parse(mlir::AsmParser &parser) {
mlir::Type eleTy;
auto location = parser.getCurrentLocation();
auto *context = parser.getContext();
bool isVolatile = false;
if (parser.parseLess() || parser.parseType(eleTy))
return {};
if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
return {};
if (parser.parseGreater())
return {};
return parser.getChecked<fir::ClassType>(location, context, eleTy,
isVolatile);
}

void fir::ClassType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getEleTy();
if (isVolatile())
printer << ", " << getVolatileKeyword();
printer << '>';
}

llvm::LogicalResult
fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
mlir::Type eleTy, bool isVolatile) {
if (mlir::isa<fir::RecordType, fir::SequenceType, fir::HeapType,
fir::PointerType, mlir::NoneType, mlir::IntegerType,
mlir::FloatType, fir::CharacterType, fir::LogicalType,
Expand Down Expand Up @@ -1057,18 +1137,32 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
// ReferenceType
//===----------------------------------------------------------------------===//

// `ref` `<` type `>`
// `ref` `<` type (`, volatile` $volatile^)? `>`
mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
return parseTypeSingleton<fir::ReferenceType>(parser);
auto location = parser.getCurrentLocation();
auto *context = parser.getContext();
mlir::Type eleTy;
bool isVolatile = false;
if (parser.parseLess() || parser.parseType(eleTy))
return {};
if (parseOptionalCommaAndKeyword(parser, getVolatileKeyword(), isVolatile))
return {};
if (parser.parseGreater())
return {};
return parser.getChecked<fir::ReferenceType>(location, context, eleTy,
isVolatile);
}

void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
printer << "<" << getEleTy() << '>';
printer << "<" << getEleTy();
if (isVolatile())
printer << ", " << getVolatileKeyword();
printer << '>';
}

llvm::LogicalResult fir::ReferenceType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
bool isVolatile) {
if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
ReferenceType, TypeDescType>(eleTy))
return emitError() << "cannot build a reference to type: " << eleTy << '\n';
Expand Down
29 changes: 27 additions & 2 deletions flang/test/Fir/invalid-types.fir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ func.func private @box3() -> !fir.boxproc<>

// -----

// expected-error@+2 {{expected non-function type}}
// expected-error@+1 {{failed to parse fir_BoxType parameter 'eleTy' which is to be a `mlir::Type`}}
// expected-error@+1 {{expected non-function type}}
func.func private @box1() -> !fir.box<>

// -----
Expand Down Expand Up @@ -105,6 +104,11 @@ func.func private @mem3() -> !fir.ref<>

// -----

// expected-error@+1 {{expected non-function type}}
func.func private @mem3() -> !fir.ref<, volatile>

// -----

// expected-error@+1 {{expected ':'}}
func.func private @arr1() -> !fir.array<*>

Expand Down Expand Up @@ -162,3 +166,24 @@ func.func private @upe() -> !fir.class<!fir.box<i32>>

// expected-error@+1 {{invalid element type}}
func.func private @upe() -> !fir.box<!fir.class<none>>

// -----

// expected-error@+1 {{invalid element type}}
func.func private @upe() -> !fir.box<!fir.class<none>, volatile>

// -----

// expected-error@+1 {{invalid element type}}
func.func private @upe() -> !fir.class<!fir.box<i32>>

// -----

// expected-error@+1 {{invalid element type}}
func.func private @upe() -> !fir.class<!fir.box<i32>, volatile>

// -----

// expected-error@+1 {{expected non-function type}}
func.func private @upe() -> !fir.class<, volatile>

36 changes: 36 additions & 0 deletions flang/unittests/Optimizer/FIRTypesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,39 @@ TEST_F(FIRTypesTest, getTypeAsString) {
EXPECT_EQ("boxchar_c8xU",
fir::getTypeAsString(fir::BoxCharType::get(&context, 1), *kindMap));
}

TEST_F(FIRTypesTest, isVolatileType) {
mlir::Type i32 = mlir::IntegerType::get(&context, 32);

mlir::Type i32NonVolatileRef = fir::ReferenceType::get(i32);
mlir::Type i32NonVolatileBox = fir::BoxType::get(i32);
mlir::Type i32NonVolatileClass = fir::ClassType::get(i32);

// Ensure the default value is false
EXPECT_EQ(i32NonVolatileRef, fir::ReferenceType::get(i32, false));
EXPECT_EQ(i32NonVolatileBox, fir::BoxType::get(i32, false));
EXPECT_EQ(i32NonVolatileClass, fir::ClassType::get(i32, false));

EXPECT_FALSE(fir::isa_volatile_type(i32));
EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileRef));
EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileBox));
EXPECT_FALSE(fir::isa_volatile_type(i32NonVolatileClass));

// Should return the same type if it's not capable of representing volatility.
EXPECT_EQ(i32, fir::updateTypeWithVolatility(i32, true));

mlir::Type i32VolatileRef =
fir::updateTypeWithVolatility(i32NonVolatileRef, true);
mlir::Type i32VolatileBox =
fir::updateTypeWithVolatility(i32NonVolatileBox, true);
mlir::Type i32VolatileClass =
fir::updateTypeWithVolatility(i32NonVolatileClass, true);

EXPECT_TRUE(fir::isa_volatile_type(i32VolatileRef));
EXPECT_TRUE(fir::isa_volatile_type(i32VolatileBox));
EXPECT_TRUE(fir::isa_volatile_type(i32VolatileClass));

EXPECT_EQ(i32VolatileRef, fir::ReferenceType::get(i32, true));
EXPECT_EQ(i32VolatileBox, fir::BoxType::get(i32, true));
EXPECT_EQ(i32VolatileClass, fir::ClassType::get(i32, true));
}