Skip to content

Commit 686b21e

Browse files
committed
Addressed review comments: added a nontemporal attribute to fir.load and fir.store, and used it to mark the operations as nontemporal
1 parent 5a46828 commit 686b21e

File tree

7 files changed

+160
-123
lines changed

7 files changed

+160
-123
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def fir_LoadOp : fir_OneResultOp<"load", [FirAliasTagOpInterface,
305305
}];
306306

307307
let arguments = (ins AnyReferenceLike:$memref,
308-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
308+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
309309

310310
let builders = [OpBuilder<(ins "mlir::Value":$refVal)>,
311311
OpBuilder<(ins "mlir::Type":$resTy, "mlir::Value":$refVal)>];
@@ -337,9 +337,8 @@ def fir_StoreOp : fir_Op<"store", [FirAliasTagOpInterface,
337337
`%p`, is undefined or null.
338338
}];
339339

340-
let arguments = (ins AnyType:$value,
341-
AnyReferenceLike:$memref,
342-
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa);
340+
let arguments = (ins AnyType:$value, AnyReferenceLike:$memref,
341+
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa, UnitAttr:$nontemporal);
343342

344343
let builders = [OpBuilder<(ins "mlir::Value":$value, "mlir::Value":$memref)>];
345344

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,17 +3567,15 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
35673567
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
35683568
boxSize, isVolatile);
35693569
} else {
3570-
unsigned alignment =
3571-
store->getAttrOfType<mlir::IntegerAttr>("alignment")
3572-
? store->getAttrOfType<mlir::IntegerAttr>("alignment").getInt()
3573-
: 0;
3574-
3575-
mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>(
3576-
loc, llvmValue, llvmMemref, alignment, store->hasAttr("volatile"),
3577-
store->hasAttr("nontemporal"));
3570+
mlir::LLVM::StoreOp storeOp =
3571+
rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
35783572

35793573
if (isVolatile)
35803574
storeOp.setVolatile_(true);
3575+
3576+
if (store.getNontemporal())
3577+
storeOp.setNontemporal(true);
3578+
35813579
newOp = storeOp;
35823580
}
35833581
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())

flang/lib/Optimizer/OpenMP/LowerNontemporal.cpp

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,62 +10,67 @@
1010
// nontemporal.
1111
//
1212
//===----------------------------------------------------------------------===//
13+
1314
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
1415
#include "flang/Optimizer/OpenMP/Passes.h"
1516
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17+
1618
using namespace mlir;
19+
1720
namespace flangomp {
1821
#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
1922
#include "flang/Optimizer/OpenMP/Passes.h.inc"
2023
} // namespace flangomp
24+
2125
namespace {
2226
class LowerNontemporalPass
2327
: public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
2428
void addNonTemporalAttr(omp::SimdOp simdOp) {
25-
if (!simdOp.getNontemporalVars().empty()) {
26-
llvm::SmallVector<mlir::Value> nontemporalOrigVars;
27-
mlir::OperandRange nontemporals = simdOp.getNontemporalVars();
28-
for (mlir::Value nontemporal : nontemporals) {
29-
nontemporalOrigVars.push_back(nontemporal);
30-
}
31-
std::function<mlir::Value(mlir::Value)> getBaseOperand =
32-
[&](mlir::Value operand) -> mlir::Value {
33-
if (mlir::isa<fir::DeclareOp>(operand.getDefiningOp()))
34-
return operand;
35-
else if (auto arrayCoorOp = llvm::dyn_cast<fir::ArrayCoorOp>(
36-
operand.getDefiningOp())) {
37-
return getBaseOperand(arrayCoorOp.getMemref());
38-
} else if (auto boxAddrOp = llvm::dyn_cast<fir::BoxAddrOp>(
39-
operand.getDefiningOp())) {
40-
return getBaseOperand(boxAddrOp.getVal());
41-
} else if (auto loadOp =
42-
llvm::dyn_cast<fir::LoadOp>(operand.getDefiningOp())) {
43-
return getBaseOperand(loadOp.getMemref());
44-
} else {
45-
return operand;
46-
}
47-
};
48-
simdOp->walk([&](Operation *op) {
49-
mlir::Value Operand = nullptr;
50-
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op)) {
51-
Operand = loadOp.getMemref();
52-
} else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op)) {
53-
Operand = storeOp.getMemref();
29+
if (simdOp.getNontemporalVars().empty())
30+
return;
31+
32+
std::function<mlir::Value(mlir::Value)> getBaseOperand =
33+
[&](mlir::Value operand) -> mlir::Value {
34+
if (mlir::isa<mlir::BlockArgument>(operand) ||
35+
(mlir::isa<fir::AllocaOp>(operand.getDefiningOp())) ||
36+
(mlir::isa<fir::DeclareOp>(operand.getDefiningOp())))
37+
return operand;
38+
39+
Operation *definingOp = operand.getDefiningOp();
40+
if (definingOp) {
41+
for (Value srcOp : definingOp->getOperands()) {
42+
return getBaseOperand(srcOp);
5443
}
55-
if (Operand && !(fir::isAllocatableType(Operand.getType()) ||
56-
fir::isPointerType((Operand.getType())))) {
57-
Operand = getBaseOperand(Operand);
58-
if (is_contained(nontemporalOrigVars, Operand)) {
59-
// Set the attribute
60-
op->setAttr("nontemporal", UnitAttr::get(op->getContext()));
61-
}
44+
}
45+
return operand;
46+
};
47+
48+
// walk through the operations and mark the load and store as nontemporal
49+
simdOp->walk([&](Operation *op) {
50+
mlir::Value operand = nullptr;
51+
52+
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
53+
operand = loadOp.getMemref();
54+
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
55+
operand = storeOp.getMemref();
56+
57+
if (operand && !(fir::isAllocatableType(operand.getType()) ||
58+
fir::isPointerType((operand.getType())))) {
59+
operand = getBaseOperand(operand);
60+
61+
if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
62+
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
63+
loadOp.setNontemporal(true);
64+
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
65+
storeOp.setNontemporal(true);
6266
}
63-
});
67+
}
68+
});
6469
}
65-
}
70+
6671
void runOnOperation() override {
6772
Operation *op = getOperation();
6873
op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
6974
}
7075
};
71-
} // namespace
76+
} // namespace

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,8 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP,
274274
addNestedPassToAllTopLevelOperations<PassConstructor>(
275275
pm, hlfir::createInlineHLFIRAssign);
276276
pm.addPass(hlfir::createConvertHLFIRtoFIR());
277-
if (enableOpenMP) {
277+
if (enableOpenMP)
278278
pm.addPass(flangomp::createLowerWorkshare());
279-
pm.addPass(flangomp::createLowerNontemporalPass());
280-
}
281279
}
282280

283281
/// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -347,6 +345,10 @@ void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
347345
config.ApproxFuncFPMath, config.NoSignedZerosFPMath, config.UnsafeFPMath,
348346
""}));
349347

348+
if (config.EnableOpenMP)
349+
pm.addNestedPass<mlir::func::FuncOp>(
350+
flangomp::createLowerNontemporalPass());
351+
350352
fir::addFIRToLLVMPass(pm, config);
351353
}
352354

flang/test/Fir/basic-program.fir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ func.func @_QQmain() {
6565
// PASSES-NEXT: InlineHLFIRAssign
6666
// PASSES-NEXT: ConvertHLFIRtoFIR
6767
// PASSES-NEXT: LowerWorkshare
68-
// PASSES-NEXT: 'func.func' Pipeline
69-
// PASSES-NEXT: LowerNontemporalPass
7068
// PASSES-NEXT: CSE
7169
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
7270
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
@@ -151,6 +149,7 @@ func.func @_QQmain() {
151149
// PASSES-NEXT: CompilerGeneratedNamesConversion
152150
// PASSES-NEXT: 'func.func' Pipeline
153151
// PASSES-NEXT: FunctionAttr
152+
// PASSES-NEXT: LowerNontemporalPass
154153
// PASSES-NEXT: FIRToLLVMLowering
155154
// PASSES-NEXT: ReconcileUnrealizedCasts
156155
// PASSES-NEXT: LLVMIRLoweringPass

flang/test/Lower/OpenMP/simd-nontemporal.f90

Lines changed: 0 additions & 67 deletions
This file was deleted.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Test lower-nontemporal pass
2+
// RUN: fir-opt --lower-nontemporal %s | FileCheck %s
3+
4+
// CHECK-LABEL: func @_QPsimd_with_nontemporal_clause
5+
func.func @_QPsimd_with_nontemporal_clause(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
6+
%c1_i32 = arith.constant 1 : i32
7+
%0 = fir.dummy_scope : !fir.dscope
8+
%1 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsimd_with_nontemporal_clauseEa"}
9+
// CHECK: %[[A_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
10+
// CHECK: %[[C_DECL:.*]] = fir.declare %{{.*}} {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
11+
%2 = fir.declare %1 {uniq_name = "_QFsimd_with_nontemporal_clauseEa"} : (!fir.ref<i32>) -> !fir.ref<i32>
12+
%3 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsimd_with_nontemporal_clauseEb"}
13+
%4 = fir.declare %3 {uniq_name = "_QFsimd_with_nontemporal_clauseEb"} : (!fir.ref<i32>) -> !fir.ref<i32>
14+
%5 = fir.alloca i32 {bindc_name = "c", uniq_name = "_QFsimd_with_nontemporal_clauseEc"}
15+
%6 = fir.declare %5 {uniq_name = "_QFsimd_with_nontemporal_clauseEc"} : (!fir.ref<i32>) -> !fir.ref<i32>
16+
%7 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_with_nontemporal_clauseEi"}
17+
%8 = fir.declare %7 {uniq_name = "_QFsimd_with_nontemporal_clauseEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
18+
%9 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsimd_with_nontemporal_clauseEn"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
19+
%10 = fir.load %9 : !fir.ref<i32>
20+
// CHECK: omp.simd nontemporal(%[[A_DECL]], %[[C_DECL]] : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
21+
// CHECK-NEXT: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
22+
omp.simd nontemporal(%2, %6 : !fir.ref<i32>, !fir.ref<i32>) private(@_QFsimd_with_nontemporal_clauseEi_private_i32 %8 -> %arg1 : !fir.ref<i32>) {
23+
omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%10) inclusive step (%c1_i32) {
24+
%11 = fir.declare %arg1 {uniq_name = "_QFsimd_with_nontemporal_clauseEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
25+
fir.store %arg2 to %11 : !fir.ref<i32>
26+
// CHECK: %[[LOAD:.*]] = fir.load %[[A_DECL]] {nontemporal} : !fir.ref<i32>
27+
%12 = fir.load %2 : !fir.ref<i32>
28+
%13 = fir.load %4 : !fir.ref<i32>
29+
%14 = arith.addi %12, %13 : i32
30+
// CHECK: %[[ADD_VAL:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
31+
// CHECK: fir.store %[[ADD_VAL]] to %[[C_DECL]] {nontemporal} : !fir.ref<i32>
32+
fir.store %14 to %6 : !fir.ref<i32>
33+
omp.yield
34+
}
35+
}
36+
return
37+
}
38+
39+
// CHECK-LABEL: func.func @_QPsimd_nontemporal_allocatable
40+
func.func @_QPsimd_nontemporal_allocatable(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
41+
%c1_i32 = arith.constant 1 : i32
42+
%c0 = arith.constant 0 : index
43+
%c100_i32 = arith.constant 100 : i32
44+
%0 = fir.dummy_scope : !fir.dscope
45+
%1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimd_nontemporal_allocatableEi"}
46+
%2 = fir.declare %1 {uniq_name = "_QFsimd_nontemporal_allocatableEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
47+
// CHECK: %[[X_DECL:.*]] = fir.declare %{{.*}} dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>,
48+
// CHECK-SAME: uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
49+
%3 = fir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsimd_nontemporal_allocatableEx"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
50+
%4 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QFsimd_nontemporal_allocatableEy"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
51+
%5 = fir.convert %c100_i32 : (i32) -> index
52+
%6 = arith.cmpi sgt, %5, %c0 : index
53+
%7 = arith.select %6, %5, %c0 : index
54+
%8 = fir.allocmem !fir.array<?xi32>, %7 {fir.must_be_heap = true, uniq_name = "_QFsimd_nontemporal_allocatableEx.alloc"}
55+
%9 = fir.shape %7 : (index) -> !fir.shape<1>
56+
%10 = fir.embox %8(%9) : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
57+
fir.store %10 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
58+
// CHECK: omp.simd nontemporal(%[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
59+
// CHECK: omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) {
60+
omp.simd nontemporal(%3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) private(@_QFsimd_nontemporal_allocatableEi_private_i32 %2 -> %arg2 : !fir.ref<i32>) {
61+
omp.loop_nest (%arg3) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32) {
62+
%16 = fir.declare %arg2 {uniq_name = "_QFsimd_nontemporal_allocatableEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
63+
fir.store %arg3 to %16 : !fir.ref<i32>
64+
// CHECK: %[[VAL1:.*]] = fir.load %[[X_DECL]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
65+
%17 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
66+
%18 = fir.load %16 : !fir.ref<i32>
67+
%19 = fir.convert %18 : (i32) -> i64
68+
// CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[VAL1]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
69+
%20 = fir.box_addr %17 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
70+
%c0_0 = arith.constant 0 : index
71+
%21:3 = fir.box_dims %17, %c0_0 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
72+
%22 = fir.shape_shift %21#0, %21#1 : (index, index) -> !fir.shapeshift<1>
73+
// CHECK: %[[ARR_COOR:.*]] = fir.array_coor %[[BOX_ADDR]](%{{.*}}) %{{.*}} : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
74+
%23 = fir.array_coor %20(%22) %19 : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
75+
// CHECK: %[[VAL2:.*]] = fir.load %[[ARR_COOR]] {nontemporal} : !fir.ref<i32>
76+
%24 = fir.load %23 : !fir.ref<i32>
77+
%25 = fir.load %4 : !fir.ref<i32>
78+
%26 = arith.addi %24, %25 : i32
79+
%27 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
80+
%28 = fir.load %16 : !fir.ref<i32>
81+
%29 = fir.convert %28 : (i32) -> i64
82+
%30 = fir.box_addr %27 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
83+
%c0_1 = arith.constant 0 : index
84+
%31:3 = fir.box_dims %27, %c0_1 : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
85+
%32 = fir.shape_shift %31#0, %31#1 : (index, index) -> !fir.shapeshift<1>
86+
%33 = fir.array_coor %30(%32) %29 : (!fir.heap<!fir.array<?xi32>>, !fir.shapeshift<1>, i64) -> !fir.ref<i32>
87+
// CHECK: fir.store %{{.*}} to %{{.*}} {nontemporal} : !fir.ref<i32>
88+
fir.store %26 to %33 : !fir.ref<i32>
89+
omp.yield
90+
}
91+
}
92+
%11 = fir.load %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
93+
%12 = fir.box_addr %11 : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
94+
fir.freemem %12 : !fir.heap<!fir.array<?xi32>>
95+
%13 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
96+
%14 = fir.shape %c0 : (index) -> !fir.shape<1>
97+
%15 = fir.embox %13(%14) : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
98+
fir.store %15 to %3 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
99+
return
100+
}
101+

0 commit comments

Comments
 (0)