Skip to content

Commit 8301e48

Browse files
authored
[flang][FIR] add FirAliasAnalysisOpInterface (#68317)
This interface allows (HL)FIR passes to add TBAA information to fir.load and fir.store. If present, these TBAA tags take precedence over those added during CodeGen. We can't reuse mlir::LLVMIR::AliasAnalysisOpInterface because that uses the mlir::LLVMIR namespace so it tries to define methods for fir operations in the wrong namespace. But I did re-use the tbaa tag type to minimise boilerplate code. The new builders are to preserve the old interface without the tbaa tag.
1 parent 2b74db6 commit 8301e48

File tree

12 files changed

+215
-10
lines changed

12 files changed

+215
-10
lines changed

flang/include/flang/Optimizer/Dialect/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ set(LLVM_TARGET_DEFINITIONS FortranVariableInterface.td)
1818
mlir_tablegen(FortranVariableInterface.h.inc -gen-op-interface-decls)
1919
mlir_tablegen(FortranVariableInterface.cpp.inc -gen-op-interface-defs)
2020

21+
set(LLVM_TARGET_DEFINITIONS FirAliasTagOpInterface.td)
22+
mlir_tablegen(FirAliasTagOpInterface.h.inc -gen-op-interface-decls)
23+
mlir_tablegen(FirAliasTagOpInterface.cpp.inc -gen-op-interface-defs)
24+
2125
set(LLVM_TARGET_DEFINITIONS CanonicalizationPatterns.td)
2226
mlir_tablegen(CanonicalizationPatterns.inc -gen-rewriters)
2327
add_public_tablegen_target(CanonicalizationPatternsIncGen)

flang/include/flang/Optimizer/Dialect/FIRDialect.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def fir_Dialect : Dialect {
3030
let dependentDialects = [
3131
// Arith dialect provides FastMathFlagsAttr
3232
// supported by some FIR operations.
33-
"arith::ArithDialect"
33+
"arith::ArithDialect",
34+
// TBAA Tag types
35+
"LLVM::LLVMDialect"
3436
];
3537
}
3638

flang/include/flang/Optimizer/Dialect/FIROps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "flang/Optimizer/Dialect/FIRAttr.h"
1313
#include "flang/Optimizer/Dialect/FIRType.h"
14+
#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h"
1415
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
include "mlir/Dialect/Arith/IR/ArithBase.td"
1818
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
19+
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
1920
include "flang/Optimizer/Dialect/FIRDialect.td"
2021
include "flang/Optimizer/Dialect/FIRTypes.td"
2122
include "flang/Optimizer/Dialect/FIRAttr.td"
2223
include "flang/Optimizer/Dialect/FortranVariableInterface.td"
24+
include "flang/Optimizer/Dialect/FirAliasTagOpInterface.td"
2325
include "mlir/IR/BuiltinAttributes.td"
2426

2527
// Base class for FIR operations.
@@ -258,7 +260,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
258260
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
259261
}
260262

261-
def fir_LoadOp : fir_OneResultOp<"load", []> {
263+
def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface]> {
262264
let summary = "load a value from a memory reference";
263265
let description = [{
264266
Load a value from a memory reference into an ssa-value (virtual register).
@@ -274,9 +276,11 @@ def fir_LoadOp : fir_OneResultOp<"load", []> {
274276
or null.
275277
}];
276278

277-
let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$memref);
279+
let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$memref,
280+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
278281

279-
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>];
282+
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
283+
OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];
280284

281285
let hasCustomAssemblyFormat = 1;
282286

@@ -285,7 +289,7 @@ def fir_LoadOp : fir_OneResultOp<"load", []> {
285289
}];
286290
}
287291

288-
def fir_StoreOp : fir_Op<"store", []> {
292+
def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface]> {
289293
let summary = "store an SSA-value to a memory location";
290294

291295
let description = [{
@@ -305,7 +309,10 @@ def fir_StoreOp : fir_Op<"store", []> {
305309
}];
306310

307311
let arguments = (ins AnyType:$value,
308-
Arg<AnyReferenceLike, "", [MemWrite]>:$memref);
312+
Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
313+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
314+
315+
let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];
309316

310317
let hasCustomAssemblyFormat = 1;
311318
let hasVerifier = 1;
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- FirAliasTagOpInterface.h ---------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains an interface for adding alias analysis information to
10+
// loads and stores
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H
15+
#define FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H
16+
17+
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/Operation.h"
19+
#include "mlir/Support/LogicalResult.h"
20+
21+
namespace fir::detail {
22+
mlir::LogicalResult verifyFirAliasTagOpInterface(mlir::Operation *op);
23+
} // namespace fir::detail
24+
25+
#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h.inc"
26+
27+
#endif // FORTRAN_OPTIMIZER_DIALECT_FIR_ALIAS_TAG_OP_INTERFACE_H
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===-- FirAliasTagOpInterface.td --------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
include "mlir/IR/Interfaces.td"
10+
11+
def FirAliasTagOpInterface : OpInterface<"FirAliasTagOpInterface"> {
12+
let description = [{
13+
An interface for memory operations that can carry alias analysis metadata.
14+
It provides setters and getters for the operation's alias analysis
15+
attributes. The default implementations of the interface methods expect
16+
the operation to have an attribute of type ArrayAttr named tbaa.
17+
Unlike the mlir::LLVM::AliasAnalysisOpInterface, this only supports tbaa.
18+
}];
19+
20+
let cppNamespace = "::fir";
21+
let verify = [{ return detail::verifyFirAliasTagOpInterface($_op); }];
22+
23+
let methods = [
24+
InterfaceMethod<
25+
/*desc=*/ "Returns the tbaa attribute or nullptr",
26+
/*returnType=*/ "mlir::ArrayAttr",
27+
/*methodName=*/ "getTBAATagsOrNull",
28+
/*args=*/ (ins),
29+
/*methodBody=*/ [{}],
30+
/*defaultImpl=*/ [{
31+
auto op = mlir::cast<ConcreteOp>(this->getOperation());
32+
return op.getTbaaAttr();
33+
}]
34+
>,
35+
InterfaceMethod<
36+
/*desc=*/ "Sets the tbaa attribute",
37+
/*returnType=*/ "void",
38+
/*methodName=*/ "setTBAATags",
39+
/*args=*/ (ins "const mlir::ArrayAttr":$attr),
40+
/*methodBody=*/ [{}],
41+
/*defaultImpl=*/ [{
42+
auto op = mlir::cast<ConcreteOp>(this->getOperation());
43+
op.setTbaaAttr(attr);
44+
}]
45+
>,
46+
InterfaceMethod<
47+
/*desc=*/ "Returns a list of all pointer operands accessed by the "
48+
"operation",
49+
/*returnType=*/ "::llvm::SmallVector<::mlir::Value>",
50+
/*methodName=*/ "getAccessedOperands",
51+
/*args=*/ (ins),
52+
/*methodBody=*/ [{}],
53+
/*defaultImpl=*/ [{
54+
auto op = mlir::cast<ConcreteOp>(this->getOperation());
55+
return {op.getMemref()};
56+
}]
57+
>
58+
];
59+
}

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,7 +3086,10 @@ struct LoadOpConversion : public FIROpConversion<fir::LoadOp> {
30863086
auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(
30873087
loc, boxPtrTy.cast<mlir::LLVM::LLVMPointerType>().getElementType(),
30883088
inputBoxStorage);
3089-
attachTBAATag(boxValue, boxTy, boxTy, nullptr);
3089+
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
3090+
boxValue.setTBAATags(*optionalTag);
3091+
else
3092+
attachTBAATag(boxValue, boxTy, boxTy, nullptr);
30903093
auto newBoxStorage =
30913094
genAllocaWithType(loc, boxPtrTy, defaultAlign, rewriter);
30923095
auto storeOp =
@@ -3097,7 +3100,10 @@ struct LoadOpConversion : public FIROpConversion<fir::LoadOp> {
30973100
mlir::Type loadTy = convertType(load.getType());
30983101
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
30993102
load.getLoc(), loadTy, adaptor.getOperands(), load->getAttrs());
3100-
attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
3103+
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
3104+
loadOp.setTBAATags(*optionalTag);
3105+
else
3106+
attachTBAATag(loadOp, load.getType(), load.getType(), nullptr);
31013107
rewriter.replaceOp(load, loadOp.getResult());
31023108
}
31033109
return mlir::success();
@@ -3341,7 +3347,10 @@ struct StoreOpConversion : public FIROpConversion<fir::StoreOp> {
33413347
newStoreOp = rewriter.create<mlir::LLVM::StoreOp>(
33423348
loc, adaptor.getOperands()[0], adaptor.getOperands()[1]);
33433349
}
3344-
attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
3350+
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
3351+
newStoreOp.setTBAATags(*optionalTag);
3352+
else
3353+
attachTBAATag(newStoreOp, storeTy, storeTy, nullptr);
33453354
rewriter.eraseOp(store);
33463355
return mlir::success();
33473356
}

flang/lib/Optimizer/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_flang_library(FIRDialect
66
FIROps.cpp
77
FIRType.cpp
88
FortranVariableInterface.cpp
9+
FirAliasTagOpInterface.cpp
910
Inliner.cpp
1011

1112
DEPENDS

flang/lib/Optimizer/Dialect/FIRDialect.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "flang/Optimizer/Dialect/FIRAttr.h"
1515
#include "flang/Optimizer/Dialect/FIROps.h"
1616
#include "flang/Optimizer/Dialect/FIRType.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1718
#include "mlir/Transforms/InliningUtils.h"
1819

1920
using namespace fir;
@@ -58,6 +59,7 @@ struct FIRInlinerInterface : public mlir::DialectInlinerInterface {
5859

5960
fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
6061
: mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
62+
getContext()->loadDialect<mlir::LLVM::LLVMDialect>();
6163
registerTypes();
6264
registerAttributes();
6365
addOperations<

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1977,8 +1977,18 @@ void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
19771977
mlir::emitError(result.location, "not a memory reference type");
19781978
return;
19791979
}
1980+
build(builder, result, eleTy, refVal);
1981+
}
1982+
1983+
void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
1984+
mlir::Type resTy, mlir::Value refVal) {
1985+
1986+
if (!refVal) {
1987+
mlir::emitError(result.location, "LoadOp has null argument");
1988+
return;
1989+
}
19801990
result.addOperands(refVal);
1981-
result.addTypes(eleTy);
1991+
result.addTypes(resTy);
19821992
}
19831993

19841994
mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) {
@@ -3249,6 +3259,11 @@ mlir::LogicalResult fir::StoreOp::verify() {
32493259
return mlir::success();
32503260
}
32513261

3262+
void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
3263+
mlir::Value value, mlir::Value memref) {
3264+
build(builder, result, value, memref, {});
3265+
}
3266+
32523267
//===----------------------------------------------------------------------===//
32533268
// StringLitOp
32543269
//===----------------------------------------------------------------------===//
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- FirAliasTagOpInterface.cpp ----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15+
16+
#include "flang/Optimizer/Dialect/FirAliasTagOpInterface.cpp.inc"
17+
18+
mlir::LogicalResult
19+
fir::detail::verifyFirAliasTagOpInterface(mlir::Operation *op) {
20+
auto iface = mlir::cast<FirAliasTagOpInterface>(op);
21+
22+
mlir::ArrayAttr tags = iface.getTBAATagsOrNull();
23+
if (!tags)
24+
return mlir::success();
25+
26+
for (mlir::Attribute iter : tags)
27+
if (!mlir::isa<mlir::LLVM::TBAATagAttr>(iter))
28+
return op->emitOpError("expected op to return array of ")
29+
<< mlir::LLVM::TBAATagAttr::getMnemonic() << " attributes";
30+
return mlir::success();
31+
}

flang/test/Fir/tbaa-codegen.fir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// test that tbaa attributes can be added to fir.load and fir.store
2+
// and that these attributes are propagated to LLVMIR
3+
4+
// RUN: tco %s | FileCheck %s
5+
6+
// subroutine simple(a)
7+
// integer, intent(inout) :: a(:)
8+
// a(1) = a(2)
9+
// end subroutine
10+
#tbaa_root = #llvm.tbaa_root<id = "Flang function root _QPsimple">
11+
#tbaa_type_desc = #llvm.tbaa_type_desc<id = "any access", members = {<#tbaa_root, 0>}>
12+
#tbaa_type_desc1 = #llvm.tbaa_type_desc<id = "any data access", members = {<#tbaa_type_desc, 0>}>
13+
#tbaa_type_desc2 = #llvm.tbaa_type_desc<id = "dummy arg data", members = {<#tbaa_type_desc1, 0>}>
14+
#tbaa_type_desc3 = #llvm.tbaa_type_desc<id = "dummy arg data/a", members = {<#tbaa_type_desc2, 0>}>
15+
#tbaa_tag = #llvm.tbaa_tag<base_type = #tbaa_type_desc3, access_type = #tbaa_type_desc3, offset = 0>
16+
module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "aarch64-unknown-linux-gnu"} {
17+
func.func @_QPsimple(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"}) {
18+
%c1 = arith.constant 1 : index
19+
%c2 = arith.constant 2 : index
20+
%0 = fir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_inout>, uniq_name = "_QFfuncEa"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
21+
%1 = fir.rebox %0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
22+
%2 = fir.array_coor %1 %c2 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
23+
%3 = fir.load %2 {tbaa = [#tbaa_tag]} : !fir.ref<i32>
24+
%4 = fir.array_coor %1 %c1 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
25+
fir.store %3 to %4 {tbaa = [#tbaa_tag]} : !fir.ref<i32>
26+
return
27+
}
28+
}
29+
30+
// CHECK-LABEL: define void @_QPsimple(
31+
// CHECK-SAME: ptr %[[ARG0:.*]]) {
32+
// [...]
33+
// load a(2):
34+
// CHECK: %[[VAL20:.*]] = getelementptr i8, ptr %{{.*}}, i64 %{{.*}}
35+
// CHECK: %[[A2:.*]] = load i32, ptr %[[VAL20]], align 4, !tbaa ![[A_ACCESS_TAG:.*]]
36+
// [...]
37+
// store a(2) to a(1):
38+
// CHECK: %[[A1:.*]] = getelementptr i8, ptr %{{.*}}, i64 %{{.*}}
39+
// CHECK: store i32 %[[A2]], ptr %[[A1]], align 4, !tbaa ![[A_ACCESS_TAG]]
40+
// CHECK: ret void
41+
// CHECK: }
42+
// CHECK: ![[A_ACCESS_TAG]] = !{![[A_ACCESS_TYPE:.*]], ![[A_ACCESS_TYPE]], i64 0}
43+
// CHECK: ![[A_ACCESS_TYPE]] = !{!"dummy arg data/a", ![[DUMMY_ARG_TYPE:.*]], i64 0}
44+
// CHECK: ![[DUMMY_ARG_TYPE]] = !{!"dummy arg data", ![[DATA_ACCESS_TYPE:.*]], i64 0}
45+
// CHECK: ![[DATA_ACCESS_TYPE]] = !{!"any data access", ![[ANY_ACCESS_TYPE:.*]], i64 0}
46+
// CHECK: ![[ANY_ACCESS_TYPE]] = !{!"any access", ![[ROOT:.*]], i64 0}
47+
// CHECK: ![[ROOT]] = !{!"Flang function root _QPsimple"}

0 commit comments

Comments
 (0)