Skip to content

FIR TBAA Pass #68414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions flang/include/flang/Optimizer/Analysis/TBAAForest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===-- TBAAForest.h - A TBAA tree for each function -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_ANALYSIS_TBAA_FOREST_H
#define FORTRAN_OPTIMIZER_ANALYSIS_TBAA_FOREST_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/DenseMap.h"
#include <string>

namespace fir {

//===----------------------------------------------------------------------===//
// TBAATree
//===----------------------------------------------------------------------===//
/// Per-function TBAA tree. Each tree contains branches for data (of various
/// kinds) and descriptor access
struct TBAATree {
//===----------------------------------------------------------------------===//
// TBAAForrest::TBAATree::SubtreeState
//===----------------------------------------------------------------------===//
/// This contains a TBAA subtree based on some parent. New tags can be added
/// under the parent using getTag.
class SubtreeState {
friend TBAATree; // only allow construction by TBAATree
public:
SubtreeState() = delete;
SubtreeState(const SubtreeState &) = delete;
SubtreeState(SubtreeState &&) = default;

mlir::LLVM::TBAATagAttr getTag(llvm::StringRef uniqueId) const;

private:
SubtreeState(mlir::MLIRContext *ctx, std::string name,
mlir::LLVM::TBAANodeAttr grandParent)
: parentId{std::move(name)}, context(ctx) {
parent = mlir::LLVM::TBAATypeDescriptorAttr::get(
context, parentId, mlir::LLVM::TBAAMemberAttr::get(grandParent, 0));
}

const std::string parentId;
mlir::MLIRContext *const context;
mlir::LLVM::TBAATypeDescriptorAttr parent;
llvm::DenseMap<llvm::StringRef, mlir::LLVM::TBAATagAttr> tagDedup;
};

SubtreeState globalDataTree;
SubtreeState allocatedDataTree;
SubtreeState dummyArgDataTree;
mlir::LLVM::TBAATypeDescriptorAttr anyAccessDesc;
mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc;
mlir::LLVM::TBAATypeDescriptorAttr anyDataTypeDesc;

static TBAATree buildTree(mlir::StringAttr functionName);

private:
TBAATree(mlir::LLVM::TBAATypeDescriptorAttr anyAccess,
mlir::LLVM::TBAATypeDescriptorAttr dataRoot,
mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc);
};

//===----------------------------------------------------------------------===//
// TBAAForrest
//===----------------------------------------------------------------------===//
/// Collection of TBAATrees, usually indexed by function (so that each function
/// has a different TBAATree)
class TBAAForrest {
public:
explicit TBAAForrest(bool separatePerFunction = true)
: separatePerFunction{separatePerFunction} {}

inline const TBAATree &operator[](mlir::func::FuncOp func) {
return getFuncTree(func.getSymNameAttr());
}
inline const TBAATree &operator[](mlir::LLVM::LLVMFuncOp func) {
return getFuncTree(func.getSymNameAttr());
}

private:
const TBAATree &getFuncTree(mlir::StringAttr symName) {
if (!separatePerFunction)
symName = mlir::StringAttr::get(symName.getContext(), "");
if (!trees.contains(symName))
trees.insert({symName, TBAATree::buildTree(symName)});
return trees.at(symName);
}

// Should each function use a different tree?
const bool separatePerFunction;
// TBAA tree per function
llvm::DenseMap<mlir::StringAttr, TBAATree> trees;
};

} // namespace fir

#endif // FORTRAN_OPTIMIZER_ANALYSIS_TBAA_FOREST_H
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ set(LLVM_TARGET_DEFINITIONS FortranVariableInterface.td)
mlir_tablegen(FortranVariableInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(FortranVariableInterface.cpp.inc -gen-op-interface-defs)

set(LLVM_TARGET_DEFINITIONS FirAliasTagOpInterface.td)
mlir_tablegen(FirAliasTagOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(FirAliasTagOpInterface.cpp.inc -gen-op-interface-defs)

set(LLVM_TARGET_DEFINITIONS CanonicalizationPatterns.td)
mlir_tablegen(CanonicalizationPatterns.inc -gen-rewriters)
add_public_tablegen_target(CanonicalizationPatternsIncGen)
Expand Down
4 changes: 3 additions & 1 deletion flang/include/flang/Optimizer/Dialect/FIRDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def fir_Dialect : Dialect {
let dependentDialects = [
// Arith dialect provides FastMathFlagsAttr
// supported by some FIR operations.
"arith::ArithDialect"
"arith::ArithDialect",
// TBAA Tag types
"LLVM::LLVMDialect"
];
}

Expand Down
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h"
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down
17 changes: 12 additions & 5 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
include "flang/Optimizer/Dialect/FIRDialect.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
include "flang/Optimizer/Dialect/FortranVariableInterface.td"
include "flang/Optimizer/Dialect/FirAliasTagOpInterface.td"
include "mlir/IR/BuiltinAttributes.td"

// Base class for FIR operations.
Expand Down Expand Up @@ -258,7 +260,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
}

def fir_LoadOp : fir_OneResultOp<"load", []> {
def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface]> {
let summary = "load a value from a memory reference";
let description = [{
Load a value from a memory reference into an ssa-value (virtual register).
Expand All @@ -274,9 +276,11 @@ def fir_LoadOp : fir_OneResultOp<"load", []> {
or null.
}];

let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$memref);
let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$memref,
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);

let builders = [OpBuilder<(ins "mlir::Value":$refVal)>];
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];

let hasCustomAssemblyFormat = 1;

Expand All @@ -285,7 +289,7 @@ def fir_LoadOp : fir_OneResultOp<"load", []> {
}];
}

def fir_StoreOp : fir_Op<"store", []> {
def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface]> {
let summary = "store an SSA-value to a memory location";

let description = [{
Expand All @@ -305,7 +309,10 @@ def fir_StoreOp : fir_Op<"store", []> {
}];

let arguments = (ins AnyType:$value,
Arg<AnyReferenceLike, "", [MemWrite]>:$memref);
Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);

let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
Expand Down
27 changes: 27 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FirAliasTagOpInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- FirAliasTagOpInterface.h ---------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains an interface for adding alias analysis information to
// loads and stores
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H
#define FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

namespace fir::detail {
mlir::LogicalResult verifyFirAliasTagOpInterface(mlir::Operation *op);
} // namespace fir::detail

#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h.inc"

#endif // FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H
59 changes: 59 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FirAliasTagOpInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===-- FirAliasTagOpInterface.td --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

include "mlir/IR/Interfaces.td"

def FirAliasTagOpInterface : OpInterface<"FirAliasTagOpInterface"> {
let description = [{
An interface for memory operations that can carry alias analysis metadata.
It provides setters and getters for the operation's alias analysis
attributes. The default implementations of the interface methods expect
the operation to have an attribute of type ArrayAttr named tbaa.
Unlike the mlir::LLVM::AliasAnalysisOpInterface, this only supports tbaa.
}];

let cppNamespace = "::fir";
let verify = [{ return detail::verifyFirAliasTagOpInterface($_op); }];

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the tbaa attribute or nullptr",
/*returnType=*/ "mlir::ArrayAttr",
/*methodName=*/ "getTBAATagsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = mlir::cast<ConcreteOp>(this->getOperation());
return op.getTbaaAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the tbaa attribute",
/*returnType=*/ "void",
/*methodName=*/ "setTBAATags",
/*args=*/ (ins "const mlir::ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = mlir::cast<ConcreteOp>(this->getOperation());
op.setTbaaAttr(attr);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns a list of all pointer operands accessed by the "
"operation",
/*returnType=*/ "::llvm::SmallVector<::mlir::Value>",
/*methodName=*/ "getAccessedOperands",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = mlir::cast<ConcreteOp>(this->getOperation());
return {op.getMemref()};
}]
>
];
}
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
std::unique_ptr<mlir::Pass> createStackArraysPass();
std::unique_ptr<mlir::Pass> createAliasTagsPass();
std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
std::unique_ptr<mlir::Pass> createAddDebugFoundationPass();
std::unique_ptr<mlir::Pass> createLoopVersioningPass();
Expand Down
20 changes: 20 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,26 @@ def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
let constructor = "::fir::createStackArraysPass()";
}

def AddAliasTags : Pass<"fir-add-alias-tags", "mlir::ModuleOp"> {
let summary = "Add tbaa tags to operations that implement FirAliasAnalysisOpInterface";
let description = [{
TBAA (type based alias analysis) is one method to pass pointer alias information
from language frontends to LLVM. This pass uses fir::AliasAnalysis to add this
information to fir.load and fir.store operations.
Additional tags are added during codegen. See fir::TBAABuilder.
This needs to be a separate pass so that it happens before structured control
flow operations are lowered to branches and basic blocks (this makes tracing
the source of values much eaiser). The other TBAA tags need to be applied to
box loads and stores which are implicit in FIR and so cannot be annotated
until codegen.
TODO: this is currently a pass on mlir::ModuleOp to avoid parallelism. In
theory, each operation could be considered in prallel, so long as there
aren't races adding new tags to the mlir context.
}];
let dependentDialects = [ "fir::FIROpsDialect" ];
let constructor = "::fir::createAliasTagsPass()";
}

def SimplifyRegionLite : Pass<"simplify-region-lite", "mlir::ModuleOp"> {
let summary = "Region simplification";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_flang_library(FIRAnalysis
AliasAnalysis.cpp
TBAAForest.cpp

DEPENDS
FIRDialect
Expand Down
60 changes: 60 additions & 0 deletions flang/lib/Optimizer/Analysis/TBAAForest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===- TBAAForest.cpp - Per-functon TBAA Trees ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Analysis/TBAAForest.h"
#include <mlir/Dialect/LLVMIR/LLVMAttrs.h>

mlir::LLVM::TBAATagAttr
fir::TBAATree::SubtreeState::getTag(llvm::StringRef uniqueName) const {
// mlir::LLVM::TBAATagAttr &tag = tagDedup[uniqueName];
// if (tag)
// return tag;
std::string id = (parentId + "/" + uniqueName).str();
mlir::LLVM::TBAATypeDescriptorAttr type =
mlir::LLVM::TBAATypeDescriptorAttr::get(
context, id, mlir::LLVM::TBAAMemberAttr::get(parent, 0));
return mlir::LLVM::TBAATagAttr::get(type, type, 0);
// return tag;
}

fir::TBAATree fir::TBAATree::buildTree(mlir::StringAttr func) {
llvm::StringRef funcName = func.getValue();
std::string rootId = ("Flang function root " + funcName).str();
mlir::MLIRContext *ctx = func.getContext();
mlir::LLVM::TBAARootAttr funcRoot =
mlir::LLVM::TBAARootAttr::get(ctx, mlir::StringAttr::get(ctx, rootId));

static constexpr llvm::StringRef anyAccessTypeDescId = "any access";
mlir::LLVM::TBAATypeDescriptorAttr anyAccess =
mlir::LLVM::TBAATypeDescriptorAttr::get(
ctx, anyAccessTypeDescId,
mlir::LLVM::TBAAMemberAttr::get(funcRoot, 0));

static constexpr llvm::StringRef anyDataAccessTypeDescId = "any data access";
mlir::LLVM::TBAATypeDescriptorAttr dataRoot =
mlir::LLVM::TBAATypeDescriptorAttr::get(
ctx, anyDataAccessTypeDescId,
mlir::LLVM::TBAAMemberAttr::get(anyAccess, 0));

static constexpr llvm::StringRef boxMemberTypeDescId = "descriptor member";
mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc =
mlir::LLVM::TBAATypeDescriptorAttr::get(
ctx, boxMemberTypeDescId,
mlir::LLVM::TBAAMemberAttr::get(anyAccess, 0));

return TBAATree{anyAccess, dataRoot, boxMemberTypeDesc};
}

fir::TBAATree::TBAATree(mlir::LLVM::TBAATypeDescriptorAttr anyAccess,
mlir::LLVM::TBAATypeDescriptorAttr dataRoot,
mlir::LLVM::TBAATypeDescriptorAttr boxMemberTypeDesc)
: globalDataTree(dataRoot.getContext(), "global data", dataRoot),
allocatedDataTree(dataRoot.getContext(), "allocated data", dataRoot),
dummyArgDataTree(dataRoot.getContext(), "dummy arg data", dataRoot),
anyAccessDesc(anyAccess), boxMemberTypeDesc(boxMemberTypeDesc),
anyDataTypeDesc(dataRoot) {}
Loading