Skip to content

Commit 3d80163

Browse files
committed
[flang] Implement !DIR$ VECTOR ALWAYS
This patch implements support for the VECTOR ALWAYS directive, which forces vectorization to occur when possible regardless of a decision by the cost model. This is done by adding an attribute to the branch into the loop in LLVM to indicate that the loop should always be vectorized.
1 parent f5dcfb9 commit 3d80163

File tree

12 files changed

+141
-7
lines changed

12 files changed

+141
-7
lines changed

flang/include/flang/Lower/PFTBuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ struct Evaluation : EvaluationVariant {
347347
parser::CharBlock position{};
348348
std::optional<parser::Label> label{};
349349
std::unique_ptr<EvaluationList> evaluationList; // nested evaluations
350+
llvm::SmallVector<const parser::CompilerDirective *> dirs;
350351
Evaluation *parentConstruct{nullptr}; // set for nodes below the top level
351352
Evaluation *lexicalSuccessor{nullptr}; // set for leaf nodes, some directives
352353
Evaluation *controlSuccessor{nullptr}; // set for some leaf nodes

flang/include/flang/Optimizer/Dialect/FIROps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
1717
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1920
#include "mlir/Interfaces/LoopLikeInterface.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2160,7 +2160,8 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
21602160
Variadic<AnyType>:$initArgs,
21612161
OptionalAttr<UnitAttr>:$unordered,
21622162
OptionalAttr<UnitAttr>:$finalValue,
2163-
OptionalAttr<ArrayAttr>:$reduceAttrs
2163+
OptionalAttr<ArrayAttr>:$reduceAttrs,
2164+
OptionalAttr<LoopAnnotationAttr>:$loop_annotation
21642165
);
21652166
let results = (outs Variadic<AnyType>:$results);
21662167
let regions = (region SizedRegion<1>:$region);

flang/include/flang/Parser/dump-parse-tree.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class ParseTreeDumper {
204204
NODE(CompilerDirective, IgnoreTKR)
205205
NODE(CompilerDirective, LoopCount)
206206
NODE(CompilerDirective, AssumeAligned)
207+
NODE(CompilerDirective, VectorAlways)
207208
NODE(CompilerDirective, NameValue)
208209
NODE(CompilerDirective, Unrecognized)
209210
NODE(parser, ComplexLiteralConstant)

flang/include/flang/Parser/parse-tree.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3334,14 +3334,15 @@ struct CompilerDirective {
33343334
TUPLE_CLASS_BOILERPLATE(AssumeAligned);
33353335
std::tuple<common::Indirection<Designator>, uint64_t> t;
33363336
};
3337+
EMPTY_CLASS(VectorAlways);
33373338
struct NameValue {
33383339
TUPLE_CLASS_BOILERPLATE(NameValue);
33393340
std::tuple<Name, std::optional<std::uint64_t>> t;
33403341
};
33413342
EMPTY_CLASS(Unrecognized);
33423343
CharBlock source;
33433344
std::variant<std::list<IgnoreTKR>, LoopCount, std::list<AssumeAligned>,
3344-
std::list<NameValue>, Unrecognized>
3345+
VectorAlways, std::list<NameValue>, Unrecognized>
33453346
u;
33463347
};
33473348

flang/lib/Lower/Bridge.cpp

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,7 +1929,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19291929

19301930
// Increment loop begin code. (Infinite/while code was already generated.)
19311931
if (!infiniteLoop && !whileCondition)
1932-
genFIRIncrementLoopBegin(incrementLoopNestInfo);
1932+
genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
19331933

19341934
// Loop body code.
19351935
auto iter = eval.getNestedEvaluations().begin();
@@ -1974,8 +1974,22 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19741974
return builder->createIntegerConstant(loc, controlType, 1); // step
19751975
}
19761976

1977+
/// Tag the fir.do_loop recorded in \p info with an LLVM loop annotation
/// requesting vectorization.
///
/// The annotation carries two pieces of information:
///  - a loop-vectorize attribute whose first parameter is `false` (printed as
///    `disable = false` in FIR); every other vectorize option is left unset;
///  - a newly created access group, passed as the last builder argument
///    (printed as `parallelAccesses` in FIR).
/// All remaining loop-annotation parameters are passed as empty defaults.
///
/// NOTE(review): this dereferences `info.doLoop`, so it presumably must only
/// be called once the fir.do_loop for \p info has been created -- confirm at
/// the call site.
void addLoopAnnotationAttr(IncrementLoopInfo &info) {
  auto *context = builder->getContext();

  // Vectorization explicitly not disabled; no width/predicate/etc. overrides.
  auto vectorize = mlir::LLVM::LoopVectorizeAttr::get(
      context, mlir::BoolAttr::get(context, false), {}, {}, {}, {}, {}, {});

  // One access group per call, so each annotated loop gets its own group.
  auto accessGroup = mlir::LLVM::AccessGroupAttr::get(context);

  // Positional builder: only `vectorize` (2nd parameter after the context)
  // and the trailing access-group list are populated.
  auto annotation = mlir::LLVM::LoopAnnotationAttr::get(
      context, {}, vectorize,
      {}, {}, {}, {}, {}, {},
      {}, {}, {}, {}, {}, {},
      {accessGroup});

  info.doLoop.setLoopAnnotationAttr(annotation);
}
1988+
19771989
/// Generate FIR to begin a structured or unstructured increment loop nest.
1978-
void genFIRIncrementLoopBegin(IncrementLoopNestInfo &incrementLoopNestInfo) {
1990+
void genFIRIncrementLoopBegin(
1991+
IncrementLoopNestInfo &incrementLoopNestInfo,
1992+
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
19791993
assert(!incrementLoopNestInfo.empty() && "empty loop nest");
19801994
mlir::Location loc = toLocation();
19811995
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
@@ -2040,6 +2054,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
20402054
}
20412055
if (info.hasLocalitySpecs())
20422056
handleLocalitySpecs(info);
2057+
2058+
for (const auto *dir : dirs) {
2059+
std::visit(
2060+
Fortran::common::visitors{
2061+
[&](const Fortran::parser::CompilerDirective::VectorAlways
2062+
&d) { addLoopAnnotationAttr(info); },
2063+
[&](const auto &) {}},
2064+
dir->u);
2065+
}
20432066
continue;
20442067
}
20452068

@@ -2573,8 +2596,30 @@ class FirConverter : public Fortran::lower::AbstractConverter {
25732596
}
25742597
}
25752598

2576-
void genFIR(const Fortran::parser::CompilerDirective &) {
2577-
// TODO
2599+
/// Record \p dir on the DO statement that it applies to.
///
/// Starting at \p e (the evaluation of the directive itself), follow
/// `lexicalSuccessor` links past any further directive evaluations. If the
/// first non-directive evaluation is a NonLabelDoStmt, a pointer to \p dir is
/// appended to that evaluation's `dirs` list so that loop lowering can consult
/// it later. Otherwise a fatal error is emitted.
///
/// \param dir the directive being lowered; only its address is stored, so it
///            must outlive the evaluation it is attached to (presumably
///            guaranteed by the parse tree's lifetime -- confirm).
/// \param e   evaluation at which to start the search; may be null.
void attachLoopDirective(const Fortran::parser::CompilerDirective &dir,
                         Fortran::lower::pft::Evaluation *e) {
  // `lexicalSuccessor` defaults to nullptr and is not set for every
  // evaluation, so the walk must stop when the chain ends rather than
  // dereferencing a null pointer.
  while (e && e->isDirective())
    e = e->lexicalSuccessor;

  if (e && e->isA<Fortran::parser::NonLabelDoStmt>()) {
    e->dirs.push_back(&dir);
    return;
  }

  // Either the chain ended or the next statement is not a DO statement.
  fir::emitFatalError(toLocation(),
                      "loop directive must appear before a loop");
}
2612+
2613+
/// Lower a compiler directive.
///
/// Only `VECTOR ALWAYS` has an effect here: it is attached to the loop that
/// follows the directive's evaluation via attachLoopDirective. Every other
/// directive alternative is silently ignored.
void genFIR(const Fortran::parser::CompilerDirective &dir) {
  // Fetch the current evaluation up front, regardless of the directive kind.
  Fortran::lower::pft::Evaluation &directiveEval = getEval();

  using VectorAlways = Fortran::parser::CompilerDirective::VectorAlways;
  if (std::holds_alternative<VectorAlways>(dir.u))
    attachLoopDirective(dir, &directiveEval);
}
25792624

25802625
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {

flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,14 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
132132
auto comparison = rewriter.create<mlir::arith::CmpIOp>(
133133
loc, arith::CmpIPredicate::sgt, itersLeft, zero);
134134

135-
rewriter.create<mlir::cf::CondBranchOp>(
135+
auto cond = rewriter.create<mlir::cf::CondBranchOp>(
136136
loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
137137
llvm::ArrayRef<mlir::Value>());
138138

139+
if (auto ann = loop.getLoopAnnotation()) {
140+
cond->setAttr("loop_annotation", *ann);
141+
}
142+
139143
// The result of the loop operation is the values of the condition block
140144
// arguments except the induction variable on the last iteration.
141145
auto args = loop.getFinalValue()

flang/lib/Parser/Fortran-parsers.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,10 +1276,13 @@ constexpr auto loopCount{
12761276
constexpr auto assumeAligned{"ASSUME_ALIGNED" >>
12771277
optionalList(construct<CompilerDirective::AssumeAligned>(
12781278
indirect(designator), ":"_tok >> digitString64))};
1279+
// Parser for the body of a VECTOR ALWAYS directive: matches the text
// "VECTOR ALWAYS" and, on success, yields an empty
// CompilerDirective::VectorAlways node (the directive takes no operands).
// The leading "DIR$ " sentinel is consumed by the enclosing
// CompilerDirective parser, which tries this alternative after assumeAligned.
// NOTE(review): the lowering test spells the directive in lower case, so case
// folding is presumably handled before this parser runs -- confirm.
constexpr auto vectorAlways{
1280+
"VECTOR ALWAYS" >> construct<CompilerDirective::VectorAlways>()};
12791281
TYPE_PARSER(beginDirective >> "DIR$ "_tok >>
12801282
sourced((construct<CompilerDirective>(ignore_tkr) ||
12811283
construct<CompilerDirective>(loopCount) ||
12821284
construct<CompilerDirective>(assumeAligned) ||
1285+
construct<CompilerDirective>(vectorAlways) ||
12831286
construct<CompilerDirective>(
12841287
many(construct<CompilerDirective::NameValue>(
12851288
name, maybe(("="_tok || ":"_tok) >> digitString64))))) /

flang/lib/Parser/unparse.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,9 @@ class UnparseVisitor {
18281828
Word("!DIR$ ASSUME_ALIGNED ");
18291829
Walk(" ", assumeAligned, ", ");
18301830
},
1831+
[&](const CompilerDirective::VectorAlways &valways) {
1832+
Word("!DIR$ VECTOR ALWAYS");
1833+
},
18311834
[&](const std::list<CompilerDirective::NameValue> &names) {
18321835
Walk("!DIR$ ", names, " ");
18331836
},

flang/lib/Semantics/resolve-names.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8886,6 +8886,9 @@ void ResolveNamesVisitor::Post(const parser::AssignedGotoStmt &x) {
88868886
}
88878887

88888888
void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) {
8889+
if (const auto *dir{
8890+
std::get_if<parser::CompilerDirective::VectorAlways>(&x.u)})
8891+
return;
88898892
if (const auto *tkr{
88908893
std::get_if<std::list<parser::CompilerDirective::IgnoreTKR>>(&x.u)}) {
88918894
if (currScope().IsTopLevel() ||

flang/test/Fir/vector-always.fir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %flang_fc1 -emit-llvm -o - %s | FileCheck %s
2+
3+
#access_group = #llvm.access_group<id = distinct[0]<>>
4+
#loop_vectorize = #llvm.loop_vectorize<disable = false>
5+
#loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
6+
7+
// CHECK-LABEL: @vector_always_
8+
// CHECK: br i1 {{.*}}, label {{.*}}, label {{.*}}, !llvm.loop ![[ANNOTATION:.*]]
9+
func.func @_QPvector_always() {
10+
%c1 = arith.constant 1 : index
11+
%c10_i32 = arith.constant 10 : i32
12+
%c1_i32 = arith.constant 1 : i32
13+
%c10 = arith.constant 10 : index
14+
%0 = fir.alloca !fir.array<10xi32> {bindc_name = "a", uniq_name = "_QFvector_alwaysEa"}
15+
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
16+
%2 = fir.declare %0(%1) {uniq_name = "_QFvector_alwaysEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
17+
%3 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFvector_alwaysEi"}
18+
%4 = fir.declare %3 {uniq_name = "_QFvector_alwaysEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
19+
%5 = fir.convert %c1_i32 : (i32) -> index
20+
%6 = fir.convert %c10_i32 : (i32) -> index
21+
%7 = fir.convert %5 : (index) -> i32
22+
%8:2 = fir.do_loop %arg0 = %5 to %6 step %c1 iter_args(%arg1 = %7) -> (index, i32) attributes {loop_annotation = #loop_annotation} {
23+
fir.store %arg1 to %4 : !fir.ref<i32>
24+
%9 = fir.load %4 : !fir.ref<i32>
25+
%10 = fir.load %4 : !fir.ref<i32>
26+
%11 = fir.convert %10 : (i32) -> i64
27+
%12 = fir.array_coor %2(%1) %11 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, i64) -> !fir.ref<i32>
28+
fir.store %9 to %12 : !fir.ref<i32>
29+
%13 = arith.addi %arg0, %c1 : index
30+
%14 = fir.convert %c1 : (index) -> i32
31+
%15 = fir.load %4 : !fir.ref<i32>
32+
%16 = arith.addi %15, %14 : i32
33+
fir.result %13, %16 : index, i32
34+
}
35+
fir.store %8#1 to %4 : !fir.ref<i32>
36+
return
37+
}
38+
39+
// CHECK: ![[ANNOTATION]] = distinct !{![[ANNOTATION]], ![[VECTORIZE:.*]], ![[PAR_ACCESS:.*]]}
40+
// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
41+
// CHECK: ![[PAR_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[DISTINCT:.*]]}
42+
// CHECK: ![[DISTINCT]] = distinct !{}

flang/test/Lower/vector-always.f90

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
! RUN: %flang_fc1 -emit-fir -o - %s | FileCheck %s
2+
3+
! CHECK: #access_group = #llvm.access_group<id = distinct[0]<>>
4+
! CHECK: #access_group1 = #llvm.access_group<id = distinct[1]<>>
5+
! CHECK: #loop_vectorize = #llvm.loop_vectorize<disable = false>
6+
! CHECK: #loop_annotation = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group>
7+
! CHECK: #loop_annotation1 = #llvm.loop_annotation<vectorize = #loop_vectorize, parallelAccesses = #access_group1>
8+
9+
! CHECK-LABEL: vector_always
10+
subroutine vector_always
11+
integer :: a(10)
12+
!dir$ vector always
13+
!CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation}
14+
do i=1,10
15+
a(i)=i
16+
end do
17+
end subroutine vector_always
18+
19+
20+
! CHECK-LABEL: intermediate_directive
21+
subroutine intermediate_directive
22+
integer :: a(10)
23+
!dir$ vector always
24+
!dir$ unknown
25+
!CHECK: fir.do_loop {{.*}} attributes {loop_annotation = #loop_annotation1}
26+
do i=1,10
27+
a(i)=i
28+
end do
29+
end subroutine intermediate_directive

0 commit comments

Comments
 (0)