Skip to content

[flang][acc] Implement type categorization for FIR types #126964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@

namespace fir::acc {

template <typename T>
struct OpenACCPointerLikeModel
: public mlir::acc::PointerLikeType::ExternalModel<
OpenACCPointerLikeModel<T>, T> {
mlir::Type getElementType(mlir::Type pointer) const {
return mlir::cast<T>(pointer).getElementType();
}
mlir::acc::VariableTypeCategory
getPointeeTypeCategory(mlir::Type pointer,
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const;
};

template <typename T>
struct OpenACCMappableModel
: public mlir::acc::MappableType::ExternalModel<OpenACCMappableModel<T>,
Expand All @@ -36,6 +49,9 @@ struct OpenACCMappableModel
llvm::SmallVector<mlir::Value>
generateAccBounds(mlir::Type type, mlir::Value var,
mlir::OpBuilder &builder) const;

mlir::acc::VariableTypeCategory getTypeCategory(mlir::Type type,
mlir::Value var) const;
};

} // namespace fir::acc
Expand Down
10 changes: 0 additions & 10 deletions flang/include/flang/Tools/PointerModels.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef FORTRAN_TOOLS_POINTER_MODELS_H
#define FORTRAN_TOOLS_POINTER_MODELS_H

#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"

/// models for FIR pointer like types that already provide a `getElementType`
Expand All @@ -24,13 +23,4 @@ struct OpenMPPointerLikeModel
}
};

template <typename T>
struct OpenACCPointerLikeModel
: public mlir::acc::PointerLikeType::ExternalModel<
OpenACCPointerLikeModel<T>, T> {
mlir::Type getElementType(mlir::Type pointer) const {
return mlir::cast<T>(pointer).getElementType();
}
};

#endif // FORTRAN_TOOLS_POINTER_MODELS_H
8 changes: 4 additions & 4 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ bool CodeGenAction::beginSourceFileAction() {
}

// Load the MLIR dialects required by Flang
mlir::DialectRegistry registry;
mlirCtx = std::make_unique<mlir::MLIRContext>(registry);
fir::support::registerNonCodegenDialects(registry);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed it because the call below loadDialects includes that logic. Namely, loadDialects, calls registerDialects, which calls registerNonCodegenDialects.

I can also remove line below which calls loadNonCodegenDialects since call to loadDialects calls loadDialect<FLANG_DIALECT_LIST> which includes the NonCodegenDialects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did remove the line below. Might as well have it cleaned up since I am touching this. I hope you agree with how it looks now. If not I can add them back despite them being extraneous.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine. Sorry I didn't really look into the logic behind other calls.

fir::support::loadNonCodegenDialects(*mlirCtx);
mlirCtx = std::make_unique<mlir::MLIRContext>();
fir::support::loadDialects(*mlirCtx);
fir::support::registerLLVMTranslation(*mlirCtx);
mlir::DialectRegistry registry;
fir::acc::registerOpenACCExtensions(registry);
mlirCtx->appendDialectRegistry(registry);

const llvm::TargetMachine &targetMachine = ci.getTargetMachine();

Expand Down
11 changes: 0 additions & 11 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1370,23 +1370,12 @@ void FIROpsDialect::registerTypes() {
TypeDescType, fir::VectorType, fir::DummyScopeType>();
fir::ReferenceType::attachInterface<
OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
fir::ReferenceType::attachInterface<
OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());

fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
*getContext());
fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
*getContext());

fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
*getContext());
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
*getContext());

fir::LLVMPointerType::attachInterface<
OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
fir::LLVMPointerType::attachInterface<
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
}

std::optional<std::pair<uint64_t, unsigned short>>
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/OpenACC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_flang_library(FIROpenACCSupport

DEPENDS
FIRBuilder
FIRCodeGen
FIRDialect
FIRDialectSupport
FIRSupport
Expand All @@ -14,6 +15,7 @@ add_flang_library(FIROpenACCSupport

LINK_LIBS
FIRBuilder
FIRCodeGen
FIRDialect
FIRDialectSupport
FIRSupport
Expand Down
143 changes: 143 additions & 0 deletions flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "flang/Optimizer/Builder/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/CodeGen/CGOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
Expand All @@ -24,6 +25,7 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"

namespace fir::acc {

Expand Down Expand Up @@ -224,4 +226,145 @@ OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds(
return {};
}

static bool isScalarLike(mlir::Type type) {
return fir::isa_trivial(type) || fir::isa_ref_type(type);
}

static bool isArrayLike(mlir::Type type) {
return mlir::isa<fir::SequenceType>(type);
}

static bool isCompositeLike(mlir::Type type) {
return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::SequenceType>::getTypeCategory(
mlir::Type type, mlir::Value var) const {
return mlir::acc::VariableTypeCategory::array;
}

template <>
mlir::acc::VariableTypeCategory
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
mlir::Value var) const {

mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);

// If the type enclosed by the box is a mappable type, then have it
// provide the type category.
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(var);

// For all arrays, despite whether they are allocatable, pointer, assumed,
// etc, we'd like to categorize them as "array".
if (isArrayLike(eleTy))
return mlir::acc::VariableTypeCategory::array;

// We got here because we don't have an array nor a mappable type. At this
// point, we know we have a type that fits the "aggregate" definition since it
// is a type with a descriptor. Try to refine it by checking if it matches the
// "composite" definition.
if (isCompositeLike(eleTy))
return mlir::acc::VariableTypeCategory::composite;

// Even if we have a scalar type - simply because it is wrapped in a box
// we want to categorize it as "nonscalar". Anything else would've been
// non-scalar anyway.
return mlir::acc::VariableTypeCategory::nonscalar;
}

static mlir::TypedValue<mlir::acc::PointerLikeType>
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
// If there is no defining op - the unwrapped reference is the base one.
mlir::Operation *op = varPtr.getDefiningOp();
if (!op)
return varPtr;

// Look to find if this value originates from an interior pointer
// calculation op.
mlir::Value baseRef =
llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
.Case<hlfir::DesignateOp>([&](auto op) {
// Get the base object.
return op.getMemref();
})
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([&](auto op) {
// Get the base array on which the coordinate is being applied.
return op.getMemref();
})
.Case<fir::CoordinateOp>([&](auto op) {
// For coordinate operation which is applied on derived type
// object, get the base object.
return op.getRef();
})
.Default([&](mlir::Operation *) { return varPtr; });

return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
}

static mlir::acc::VariableTypeCategory
categorizePointee(mlir::Type pointer,
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) {
// FIR uses operations to compute interior pointers.
// So for example, an array element or composite field access to a float
// value would both be represented as !fir.ref<f32>. We do not want to treat
// such a reference as a scalar. Thus unwrap interior pointer calculations.
auto baseRef = getBaseRef(varPtr);
mlir::Type eleTy = baseRef.getType().getElementType();

if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
return mappableTy.getTypeCategory(varPtr);

if (isScalarLike(eleTy))
return mlir::acc::VariableTypeCategory::scalar;
if (isArrayLike(eleTy))
return mlir::acc::VariableTypeCategory::array;
if (isCompositeLike(eleTy))
return mlir::acc::VariableTypeCategory::composite;
if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
return mlir::acc::VariableTypeCategory::nonscalar;
// "pointers" - in the sense of raw address point-of-view, are considered
// scalars. However
if (mlir::isa<fir::LLVMPointerType>(eleTy))
return mlir::acc::VariableTypeCategory::scalar;

// Without further checking, this type cannot be categorized.
return mlir::acc::VariableTypeCategory::uncategorized;
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::ReferenceType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::PointerType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::HeapType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

template <>
mlir::acc::VariableTypeCategory
OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory(
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
mlir::Type varType) const {
return categorizePointee(pointer, varPtr, varType);
}

} // namespace fir::acc
9 changes: 9 additions & 0 deletions flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
fir::SequenceType::attachInterface<OpenACCMappableModel<fir::SequenceType>>(
*ctx);
fir::BoxType::attachInterface<OpenACCMappableModel<fir::BaseBoxType>>(*ctx);

fir::ReferenceType::attachInterface<
OpenACCPointerLikeModel<fir::ReferenceType>>(*ctx);
fir::PointerType::attachInterface<
OpenACCPointerLikeModel<fir::PointerType>>(*ctx);
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
*ctx);
fir::LLVMPointerType::attachInterface<
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx);
});
}

Expand Down
2 changes: 2 additions & 0 deletions flang/test/Fir/OpenACC/openacc-mappable.fir
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,

// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
// CHECK: Type category: array
// CHECK: Size: 40
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
// CHECK: Mappable: !fir.array<10xf32>
// CHECK: Type category: array
// CHECK: Size: 40
49 changes: 49 additions & 0 deletions flang/test/Fir/OpenACC/openacc-type-categories.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s

program main
real :: scalar
real, allocatable :: scalaralloc
type tt
real :: field
real :: fieldarray(10)
end type tt
type(tt) :: ttvar
real :: arrayconstsize(10)
real, allocatable :: arrayalloc(:)
complex :: complexvar
character*1 :: charvar

!$acc enter data copyin(scalar, scalaralloc, ttvar, arrayconstsize, arrayalloc)
!$acc enter data copyin(complexvar, charvar, ttvar%field, ttvar%fieldarray, arrayconstsize(1))
end program

! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalar", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: scalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalaralloc", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<f32>>>
! CHECK: Type category: nonscalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.type<_QFTtt{field:f32,fieldarray:!fir.array<10xf32>}>>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayalloc", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "complexvar", structured = false}
! CHECK: Pointer-like: !fir.ref<complex<f32>>
! CHECK: Type category: scalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "charvar", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.char<1>>
! CHECK: Type category: nonscalar
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%field", structured = false}
! CHECK: Pointer-like: !fir.ref<f32>
! CHECK: Type category: composite
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%fieldarray", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize(1)", structured = false}
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
! CHECK: Type category: array
Loading