Skip to content

Commit 1a4b0ce

Browse files
clementvalIanWood1
authored andcommitted
[flang][cuda] Update stream type for cuf kernel op (llvm#136627)
Update the type of the stream operand to be similar to KernelLaunchOp.
1 parent 1df0244 commit 1a4b0ce

File tree

4 files changed

+13
-22
lines changed

4 files changed

+13
-22
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -254,24 +254,19 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
254254
represented by a 0 constant value.
255255
}];
256256

257-
let arguments = (ins
258-
Variadic<I32>:$grid, // empty means `*`
259-
Variadic<I32>:$block, // empty means `*`
260-
Optional<I32>:$stream,
261-
Variadic<Index>:$lowerbound,
262-
Variadic<Index>:$upperbound,
263-
Variadic<Index>:$step,
264-
OptionalAttr<I64Attr>:$n,
265-
Variadic<AnyType>:$reduceOperands,
266-
OptionalAttr<ArrayAttr>:$reduceAttrs
267-
);
257+
let arguments = (ins Variadic<I32>:$grid, // empty means `*`
258+
Variadic<I32>:$block, // empty means `*`
259+
Optional<fir_ReferenceType>:$stream, Variadic<Index>:$lowerbound,
260+
Variadic<Index>:$upperbound, Variadic<Index>:$step,
261+
OptionalAttr<I64Attr>:$n, Variadic<AnyType>:$reduceOperands,
262+
OptionalAttr<ArrayAttr>:$reduceAttrs);
268263

269264
let regions = (region AnyRegion:$region);
270265

271266
let assemblyFormat = [{
272267
`<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
273268
custom<CUFKernelValues>($block, type($block))
274-
( `,` `stream` `=` $stream^ )? `>` `>` `>`
269+
( `,` `stream` `=` $stream^ `:` qualified(type($stream)))? `>` `>` `>`
275270
( `reduce` `(` $reduceOperands^ `:` type($reduceOperands) `:` $reduceAttrs `)` )?
276271
custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
277272
$upperbound, type($upperbound), $step, type($step))

flang/lib/Lower/Bridge.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,7 +3097,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
30973097

30983098
llvm::SmallVector<mlir::Value> gridValues;
30993099
llvm::SmallVector<mlir::Value> blockValues;
3100-
mlir::Value streamValue;
3100+
mlir::Value streamAddr;
31013101

31023102
if (launchConfig) {
31033103
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
@@ -3130,10 +3130,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31303130
}
31313131

31323132
if (stream)
3133-
streamValue = builder->createConvert(
3134-
loc, builder->getI32Type(),
3135-
fir::getBase(
3136-
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
3133+
streamAddr = fir::getBase(
3134+
genExprAddr(*Fortran::semantics::GetExpr(*stream), stmtCtx));
31373135
}
31383136

31393137
const auto &outerDoConstruct =
@@ -3267,7 +3265,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
32673265
}
32683266

32693267
auto op = builder->create<cuf::KernelOp>(
3270-
loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n,
3268+
loc, gridValues, blockValues, streamAddr, lbs, ubs, steps, n,
32713269
mlir::ValueRange(reduceOperands), builder->getArrayAttr(reduceAttrs));
32723270
builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
32733271
ivLocs);

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ llvm::LogicalResult cuf::KernelOp::verify() {
271271
return emitOpError("expect reduce attributes to be ReduceAttr");
272272
}
273273
}
274-
return mlir::success();
274+
return checkStreamType(*this);
275275
}
276276

277277
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ subroutine sub1()
7575
end do
7676
end
7777

78-
! CHECK: %[[STREAM_LOAD:.*]] = fir.load %[[STREAM]]#0 : !fir.ref<i64>
79-
! CHECK: %[[STREAM_I32:.*]] = fir.convert %[[STREAM_LOAD]] : (i64) -> i32
80-
! CHECK: cuf.kernel<<<*, *, stream = %[[STREAM_I32]]>>>
78+
! CHECK: cuf.kernel<<<*, *, stream = %[[STREAM]]#0 : !fir.ref<i64>>>>
8179

8280

8381
! Test lowering with unstructured construct inside.

0 commit comments

Comments
 (0)