Skip to content

Commit be1ca42

Browse files
committed
[flang] Implement !DIR$ VECTOR ALWAYS
This patch implements support for the VECTOR ALWAYS directive, which forces vectorization to occur when possible, regardless of the decision made by the cost model. This is done by adding an attribute to the branch into the loop in LLVM to indicate that the loop should always be vectorized.
1 parent 2f1e232 commit be1ca42

File tree

12 files changed

+140
-7
lines changed

12 files changed

+140
-7
lines changed

flang/include/flang/Lower/PFTBuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct Evaluation : EvaluationVariant {
347347
parser::CharBlock position{};
348348
std::optional<parser::Label> label{};
349349
std::unique_ptr<EvaluationList> evaluationList; // nested evaluations
350+
llvm::SmallVector<const parser::CompilerDirective *> dirs;
350351
Evaluation *parentConstruct{nullptr}; // set for nodes below the top level
351352
Evaluation *lexicalSuccessor{nullptr}; // set for leaf nodes, some directives
352353
Evaluation *controlSuccessor{nullptr}; // set for some leaf nodes

flang/include/flang/Optimizer/Dialect/FIROps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
1717
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1920
#include "mlir/Interfaces/LoopLikeInterface.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2096,7 +2096,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
20962096
Index:$step,
20972097
Variadic<AnyType>:$initArgs,
20982098
OptionalAttr<UnitAttr>:$unordered,
2099-
OptionalAttr<UnitAttr>:$finalValue
2099+
OptionalAttr<UnitAttr>:$finalValue,
2100+
OptionalAttr<LoopAnnotationAttr>:$loop_annotation
21002101
);
21012102
let results = (outs Variadic<AnyType>:$results);
21022103
let regions = (region SizedRegion<1>:$region);

flang/include/flang/Parser/dump-parse-tree.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ class ParseTreeDumper {
206206
NODE(CompilerDirective, IgnoreTKR)
207207
NODE(CompilerDirective, LoopCount)
208208
NODE(CompilerDirective, AssumeAligned)
209+
NODE(CompilerDirective, VectorAlways)
209210
NODE(CompilerDirective, NameValue)
210211
NODE(CompilerDirective, Unrecognized)
211212
NODE(parser, ComplexLiteralConstant)

flang/include/flang/Parser/parse-tree.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3316,14 +3316,15 @@ struct CompilerDirective {
33163316
TUPLE_CLASS_BOILERPLATE(AssumeAligned);
33173317
std::tuple<common::Indirection<Designator>, uint64_t> t;
33183318
};
3319+
EMPTY_CLASS(VectorAlways);
33193320
struct NameValue {
33203321
TUPLE_CLASS_BOILERPLATE(NameValue);
33213322
std::tuple<Name, std::optional<std::uint64_t>> t;
33223323
};
33233324
EMPTY_CLASS(Unrecognized);
33243325
CharBlock source;
33253326
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
3326-
std::list<NameValue>, Unrecognized>
3327+
VectorAlways, std::list<NameValue>, Unrecognized>
33273328
u;
33283329
};
33293330

flang/lib/Lower/Bridge.cpp

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,7 +1881,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
18811881

18821882
// Increment loop begin code. (Infinite/while code was already generated.)
18831883
if (!infiniteLoop && !whileCondition)
1884-
genFIRIncrementLoopBegin(incrementLoopNestInfo);
1884+
genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
18851885

18861886
// Loop body code.
18871887
auto iter = eval.getNestedEvaluations().begin();
@@ -1926,8 +1926,22 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19261926
return builder->createIntegerConstant(loc, controlType, 1); // step
19271927
}
19281928

1929+
/// Build an LLVM-dialect loop annotation and attach it to the loop operation
/// held in \p info (`info.doLoop`) via its `loop_annotation` attribute.
///
/// The annotation carries a vectorize attribute whose first parameter is
/// `false` and a single, newly created access group; every other parameter
/// of both attributes is left defaulted (`{}`).
void addLoopAnnotationAttr(IncrementLoopInfo &info) {
1930+
// Passed as the first parameter of LoopVectorizeAttr below.
mlir::BoolAttr f = mlir::BoolAttr::get(builder->getContext(), false);
1931+
mlir::LLVM::LoopVectorizeAttr va = mlir::LLVM::LoopVectorizeAttr::get(
1932+
builder->getContext(), f, {}, {}, {}, {}, {}, {});
1933+
// A new access group is requested on every call.
mlir::LLVM::AccessGroupAttr ag =
1934+
mlir::LLVM::AccessGroupAttr::get(builder->getContext());
1935+
// Only the vectorize attribute and the trailing access-group list are set.
mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
1936+
builder->getContext(), {}, va, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
1937+
{}, {}, {ag});
1938+
info.doLoop.setLoopAnnotationAttr(la);
1939+
}
1940+
19291941
/// Generate FIR to begin a structured or unstructured increment loop nest.
1930-
void genFIRIncrementLoopBegin(IncrementLoopNestInfo &incrementLoopNestInfo) {
1942+
void genFIRIncrementLoopBegin(
1943+
IncrementLoopNestInfo &incrementLoopNestInfo,
1944+
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
19311945
assert(!incrementLoopNestInfo.empty() && "empty loop nest");
19321946
mlir::Location loc = toLocation();
19331947
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
@@ -1978,6 +1992,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19781992
}
19791993
if (info.hasLocalitySpecs())
19801994
handleLocalitySpecs(info);
1995+
1996+
for (const auto *dir : dirs) {
1997+
std::visit(
1998+
Fortran::common::visitors{
1999+
[&](const Fortran::parser::CompilerDirective::VectorAlways
2000+
&d) { addLoopAnnotationAttr(info); },
2001+
[&](const auto &) {}},
2002+
dir->u);
2003+
}
19812004
continue;
19822005
}
19832006

@@ -2508,8 +2531,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25082531
}
25092532
}
25102533

2511-
void genFIR(const Fortran::parser::CompilerDirective &) {
2512-
// TODO
2534+
/// Record \p dir on the DO statement evaluation that follows it.
///
/// Starting from \p e, follows `lexicalSuccessor` links while the current
/// evaluation reports `isDirective()`. If the evaluation reached is a
/// `NonLabelDoStmt`, the address of \p dir is appended to its `dirs` list;
/// otherwise a fatal error is emitted at the current location.
///
/// NOTE(review): the skip loop dereferences `lexicalSuccessor` without a null
/// check. Confirm a directive can never reach here with a null successor
/// (e.g. as the last evaluation in its list).
/// NOTE(review): only a pointer to \p dir is stored; presumably the parse
/// tree outlives the evaluation -- verify.
void attachLoopDirective(const Fortran::parser::CompilerDirective &dir,
2535+
Fortran::lower::pft::Evaluation *e) {
2536+
// Skip over this directive and any directives that follow it.
while (e->isDirective()) {
2537+
e = e->lexicalSuccessor;
2538+
}
2539+
2540+
if (e->isA<Fortran::parser::NonLabelDoStmt>()) {
2541+
e->dirs.push_back(&dir);
2542+
} else {
2543+
fir::emitFatalError(toLocation(), "loop directive must appear before a loop");
2544+
}
2545+
}
2546+
2547+
/// Handle a compiler directive during lowering.
///
/// Dispatches on the directive's variant `dir.u`. A `VectorAlways` directive
/// is handed to `attachLoopDirective` together with the current evaluation;
/// every other alternative is ignored by this function.
void genFIR(const Fortran::parser::CompilerDirective &dir) {
2548+
Fortran::lower::pft::Evaluation &eval = getEval();
2549+
2550+
std::visit(
2551+
Fortran::common::visitors{
2552+
[&](const Fortran::parser::CompilerDirective::VectorAlways &) {
2553+
attachLoopDirective(dir, &eval);
2554+
},
2555+
[&](const auto &) {}},
2556+
dir.u);
25132557
}
25142558

25152559
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {

flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,14 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
132132
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
133133
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
134134

135-
rewriter.create<mlir::cf::CondBranchOp>(
135+
auto cond = rewriter.create<mlir::cf::CondBranchOp>(
136136
loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
137137
llvm::ArrayRef<mlir::Value>());
138138

139+
if (auto ann = loop.getLoopAnnotation()) {
140+
cond->setAttr("loop_annotation", *ann);
141+
}
142+
139143
// The result of the loop operation is the values of the condition block
140144
// arguments except the induction variable on the last iteration.
141145
auto args = loop.getFinalValue()

flang/lib/Parser/Fortran-parsers.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,10 +1276,13 @@ constexpr auto loopCount{
12761276
constexpr auto assumeAligned{"ASSUME_ALIGNED" >>
12771277
optionalList(construct<CompilerDirective::AssumeAligned>(
12781278
indirect(designator), ":"_tok >> digitString64))};
1279+
// Parser for the body of a VECTOR ALWAYS directive: matches the text
// "VECTOR ALWAYS" and constructs a CompilerDirective::VectorAlways node.
constexpr auto vectorAlways{
1280+
"VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()};
12791281
TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
12801282
sourced((construct<CompilerDirective>(ignore_tkr) ||
12811283
construct<CompilerDirective>(loopCount) ||
12821284
construct<CompilerDirective>(assumeAligned) ||
1285+
construct<CompilerDirective>(vectorAlways) ||
12831286
construct<CompilerDirective>(
12841287
many(construct<CompilerDirective::NameValue>(
12851288
name, maybe(("="_tok || ":"_tok) >> digitString64))))) /

flang/lib/Parser/unparse.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,6 +1824,9 @@ class UnparseVisitor {
18241824
Word("!DIR$ ASSUME_ALIGNED ");
18251825
Walk(" ", assumeAligned, ", ");
18261826
},
1827+
[&](const CompilerDirective::VectorAlways &valways) {
1828+
Word("!DIR$ VECTOR ALWAYS");
1829+
},
18271830
[&](const std::list<CompilerDirective::NameValue> &names) {
18281831
Walk("!DIR$ ", names, " ");
18291832
},

flang/lib/Semantics/resolve-names.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8854,6 +8854,9 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
88548854
}
88558855

88568856
void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
8857+
if (const auto *dir{
8858+
std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
8859+
return;
88578860
if (const auto *tkr{
88588861
std::get_if<std::list<parser::CompilerDirective::IgnoreTKR>>(&x.u)}) {
88598862
if (currScope().IsTopLevel() ||

flang/test/Fir/vector-always.fir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
2+
3+
#access_group = #llvm.access_group<id = distinct[0]<>>
4+
#loop_vectorize = #llvm.loop_vectorize<disable = false>
5+
#loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
6+
7+
// CHECK-LABEL: @vector_always_
8+
// CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
9+
func.func @_QPvector_always() {
10+
%c1 = arith.constant 1 : index
11+
%c10_i32 = arith.constant 10 : i32
12+
%c1_i32 = arith.constant 1 : i32
13+
%c10 = arith.constant 10 : index
14+
%0 = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFvector_alwaysEa"}
15+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
16+
%2 = fir.declare %0(%1) {uniq_name = "_QFvector_alwaysEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
17+
%3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFvector_alwaysEi"}
18+
%4 = fir.declare %3 {uniq_name = "_QFvector_alwaysEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
19+
%5 = fir.convert %c1_i32 : (i32) -> index
20+
%6 = fir.convert %c10_i32 : (i32) -> index
21+
%7 = fir.convert %5 : (index) -> i32
22+
%8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loop_annotation = #loop_annotation} {
23+
fir.store %arg1 to %4 : !fir.ref<i32>
24+
%9 = fir.load %4 : !fir.ref<i32>
25+
%10 = fir.load %4 : !fir.ref<i32>
26+
%11 = fir.convert %10 : (i32) -> i64
27+
%12 = fir.array_coor %2(%1) %11 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, i64) -> !fir.ref<i32>
28+
fir.store %9 to %12 : !fir.ref<i32>
29+
%13 = arith.addi %arg0, %c1 : index
30+
%14 = fir.convert %c1 : (index) -> i32
31+
%15 = fir.load %4 : !fir.ref<i32>
32+
%16 = arith.addi %15, %14 : i32
33+
fir.result %13, %16 : index, i32
34+
}
35+
fir.store %8#1 to %4 : !fir.ref<i32>
36+
return
37+
}
38+
39+
// CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[VECTORIZE:.*]], ![[PAR_ACCESS:.*]]}
40+
// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
41+
// CHECK: ![[PAR_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[DISTINCT:.*]]}
42+
// CHECK: ![[DISTINCT]] = distinct !{}

flang/test/Lower/vector-always.f90

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
! RUN: %flang_fc1 -emit-fir -o - %s | FileCheck %s
2+
3+
! CHECK: #access_group = #llvm.access_group<id = distinct[0]<>>
4+
! CHECK: #access_group1 = #llvm.access_group<id = distinct[1]<>>
5+
! CHECK: #loop_vectorize = #llvm.loop_vectorize<disable = false>
6+
! CHECK: #loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
7+
! CHECK: #loop_annotation1 = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group1>
8+
9+
! CHECK-LABEL: vector_always
10+
subroutine vector_always
11+
integer :: a(10)
12+
!dir$ vector always
13+
!CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation}
14+
do i=1,10
15+
a(i)=i
16+
end do
17+
end subroutine vector_always
18+
19+
20+
! CHECK-LABEL: intermediate_directive
21+
subroutine intermediate_directive
22+
integer :: a(10)
23+
!dir$ vector always
24+
!dir$ unknown
25+
!CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation1}
26+
do i=1,10
27+
a(i)=i
28+
end do
29+
end subroutine intermediate_directive

0 commit comments

Comments
 (0)