Skip to content

Commit 7665d3d

Browse files
ImanHosseiniIman Hosseiniclementval
authored
[flang] Add reductions for CUF Kernels: Lowering (llvm#95184)
* Add reductionOperands and reductionAttrs to cuf's KernelOp. * Parsing is already working and the tree has the info: here I make the Bridge emit the updated KernelOp with reduction information added. * Check |reductionAttrs| = |reductionOperands| in verifier * Add a test @clementval @vzakhari --------- Co-authored-by: Iman Hosseini <[email protected]> Co-authored-by: Valentin Clement (バレンタイン クレメン) <[email protected]>
1 parent ca33796 commit 7665d3d

File tree

4 files changed

+103
-4
lines changed

4 files changed

+103
-4
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
1818
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
1919
include "flang/Optimizer/Dialect/FIRTypes.td"
20+
include "flang/Optimizer/Dialect/FIRAttr.td"
2021
include "mlir/Interfaces/LoopLikeInterface.td"
2122
include "mlir/IR/BuiltinAttributes.td"
2223

@@ -249,7 +250,9 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
249250
Variadic<Index>:$lowerbound,
250251
Variadic<Index>:$upperbound,
251252
Variadic<Index>:$step,
252-
OptionalAttr<I64Attr>:$n
253+
OptionalAttr<I64Attr>:$n,
254+
Variadic<AnyType>:$reduceOperands,
255+
OptionalAttr<ArrayAttr>:$reduceAttrs
253256
);
254257

255258
let regions = (region AnyRegion:$region);
@@ -258,11 +261,29 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
258261
`<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
259262
custom<CUFKernelValues>($block, type($block))
260263
( `,` `stream` `=` $stream^ )? `>` `>` `>`
264+
( `reduce` `(` $reduceOperands^ `:` type($reduceOperands) `:` $reduceAttrs `)` )?
261265
custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
262266
$upperbound, type($upperbound), $step, type($step))
263267
attr-dict
264268
}];
265269

270+
let extraClassDeclaration = [{
271+
/// Get Number of variadic operands
272+
unsigned getNumOperands(unsigned idx) {
273+
auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
274+
getOperandSegmentSizeAttr());
275+
return static_cast<unsigned>(segments[idx]);
276+
}
277+
// Get Number of reduction operands
278+
unsigned getNumReduceOperands() {
279+
return getNumOperands(7);
280+
}
281+
/// Does the operation hold operands for reduction variables
282+
bool hasReduceOperands() {
283+
return getNumReduceOperands() > 0;
284+
}
285+
}];
286+
266287
let hasVerifier = 1;
267288
}
268289

flang/lib/Lower/Bridge.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2675,6 +2675,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
26752675
std::get<2>(dir.t);
26762676
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
26772677
std::get<3>(dir.t);
2678+
const std::list<Fortran::parser::CUFReduction> &cufreds =
2679+
std::get<4>(dir.t);
2680+
2681+
llvm::SmallVector<mlir::Value> reduceOperands;
2682+
llvm::SmallVector<mlir::Attribute> reduceAttrs;
2683+
2684+
for (const Fortran::parser::CUFReduction &cufred : cufreds) {
2685+
fir::ReduceOperationEnum redOpEnum = getReduceOperationEnum(
2686+
std::get<Fortran::parser::ReductionOperator>(cufred.t));
2687+
const std::list<Fortran::parser::Scalar<Fortran::parser::Variable>>
2688+
&scalarvars = std::get<1>(cufred.t);
2689+
for (const Fortran::parser::Scalar<Fortran::parser::Variable> &scalarvar :
2690+
scalarvars) {
2691+
auto reduce_attr =
2692+
fir::ReduceAttr::get(builder->getContext(), redOpEnum);
2693+
reduceAttrs.push_back(reduce_attr);
2694+
const Fortran::parser::Variable &var = scalarvar.thing;
2695+
if (const auto *iDesignator = std::get_if<
2696+
Fortran::common::Indirection<Fortran::parser::Designator>>(
2697+
&var.u)) {
2698+
const Fortran::parser::Designator &designator = iDesignator->value();
2699+
if (const auto *name =
2700+
Fortran::semantics::getDesignatorNameIfDataRef(designator)) {
2701+
auto val = getSymbolAddress(*name->symbol);
2702+
reduceOperands.push_back(val);
2703+
}
2704+
}
2705+
}
2706+
}
26782707

26792708
auto isOnlyStars =
26802709
[&](const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
@@ -2777,8 +2806,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
27772806
loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
27782807
}
27792808

2780-
auto op = builder->create<cuf::KernelOp>(loc, gridValues, blockValues,
2781-
streamValue, lbs, ubs, steps, n);
2809+
auto op = builder->create<cuf::KernelOp>(
2810+
loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n,
2811+
mlir::ValueRange(reduceOperands), builder->getArrayAttr(reduceAttrs));
27822812
builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
27832813
ivLocs);
27842814
mlir::Block &b = op.getRegion().back();

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
1414
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
1515
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
16+
#include "flang/Optimizer/Dialect/FIRAttr.h"
1617
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
@@ -227,7 +228,17 @@ mlir::LogicalResult cuf::KernelOp::verify() {
227228
getLowerbound().size() != getStep().size())
228229
return emitOpError(
229230
"expect same number of values in lowerbound, upperbound and step");
230-
231+
auto reduceAttrs = getReduceAttrs();
232+
std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0;
233+
if (getReduceOperands().size() != reduceAttrsSize)
234+
return emitOpError("expect same number of values in reduce operands and "
235+
"reduce attributes");
236+
if (reduceAttrs) {
237+
for (const auto &attr : reduceAttrs.value()) {
238+
if (!mlir::isa<fir::ReduceAttr>(attr))
239+
return emitOpError("expect reduce attributes to be ReduceAttr");
240+
}
241+
}
231242
return mlir::success();
232243
}
233244

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
! Test CUDA Fortran kernel do reduction
2+
! RUN: bbc -emit-fir -fcuda -o - %s | FileCheck %s
3+
4+
module mod1
5+
contains
6+
subroutine host_sub()
7+
integer, parameter :: asize = 4
8+
integer, device :: adev(asize)
9+
integer :: ahost(asize)
10+
integer :: q
11+
integer, device :: add_reduce_var
12+
integer, device :: mul_reduce_var
13+
! CHECK: %[[VAL_0:.*]] = fir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Fhost_subEadd_reduce_var"} : (!fir.ref<i32>) -> !fir.ref<i32>
14+
! CHECK: %[[VAL_1:.*]] = fir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Fhost_subEmul_reduce_var"} : (!fir.ref<i32>) -> !fir.ref<i32>
15+
do i = 1, asize
16+
ahost(i) = i
17+
enddo
18+
adev = ahost
19+
add_reduce_var = 0.0
20+
mul_reduce_var = 1.0
21+
! CHECK: {{.*}} reduce(%[[VAL_0:.*]], %[[VAL_1:.*]] : !fir.ref<i32>, !fir.ref<i32> : [#fir.reduce_attr<add>, #fir.reduce_attr<multiply>]) {{.*}}
22+
!$cuf kernel do <<< *, * >>> reduce(+:add_reduce_var) reduce(*:mul_reduce_var)
23+
do i = 1, asize
24+
add_reduce_var = add_reduce_var + adev(i)
25+
mul_reduce_var = mul_reduce_var * adev(i)
26+
end do
27+
q = rsum
28+
ahost = adev
29+
print *, q
30+
end
31+
end
32+
33+
program test
34+
use mod1
35+
implicit none
36+
call host_sub()
37+
end program test

0 commit comments

Comments
 (0)