Skip to content

Commit f84b83e

Browse files
authored
Revert "[flang][cuda] Update stream type for cuf kernel op (#136627)"
This reverts commit 46e7347.
1 parent 46e7347 commit f84b83e

File tree

4 files changed

+22
-13
lines changed

4 files changed

+22
-13
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,19 +254,24 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
254254
represented by a 0 constant value.
255255
}];
256256

257-
let arguments = (ins Variadic<I32>:$grid, // empty means `*`
258-
Variadic<I32>:$block, // empty means `*`
259-
Optional<fir_ReferenceType>:$stream, Variadic<Index>:$lowerbound,
260-
Variadic<Index>:$upperbound, Variadic<Index>:$step,
261-
OptionalAttr<I64Attr>:$n, Variadic<AnyType>:$reduceOperands,
262-
OptionalAttr<ArrayAttr>:$reduceAttrs);
257+
let arguments = (ins
258+
Variadic<I32>:$grid, // empty means `*`
259+
Variadic<I32>:$block, // empty means `*`
260+
Optional<I32>:$stream,
261+
Variadic<Index>:$lowerbound,
262+
Variadic<Index>:$upperbound,
263+
Variadic<Index>:$step,
264+
OptionalAttr<I64Attr>:$n,
265+
Variadic<AnyType>:$reduceOperands,
266+
OptionalAttr<ArrayAttr>:$reduceAttrs
267+
);
263268

264269
let regions = (region AnyRegion:$region);
265270

266271
let assemblyFormat = [{
267272
`<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
268273
custom<CUFKernelValues>($block, type($block))
269-
( `,` `stream` `=` $stream^ `:` qualified(type($stream)))? `>` `>` `>`
274+
( `,` `stream` `=` $stream^ )? `>` `>` `>`
270275
( `reduce` `(` $reduceOperands^ `:` type($reduceOperands) `:` $reduceAttrs `)` )?
271276
custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
272277
$upperbound, type($upperbound), $step, type($step))

flang/lib/Lower/Bridge.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3097,7 +3097,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
30973097

30983098
llvm::SmallVector<mlir::Value> gridValues;
30993099
llvm::SmallVector<mlir::Value> blockValues;
3100-
mlir::Value streamAddr;
3100+
mlir::Value streamValue;
31013101

31023102
if (launchConfig) {
31033103
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
@@ -3130,8 +3130,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31303130
}
31313131

31323132
if (stream)
3133-
streamAddr = fir::getBase(
3134-
genExprAddr(*Fortran::semantics::GetExpr(*stream), stmtCtx));
3133+
streamValue = builder->createConvert(
3134+
loc, builder->getI32Type(),
3135+
fir::getBase(
3136+
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
31353137
}
31363138

31373139
const auto &outerDoConstruct =
@@ -3265,7 +3267,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
32653267
}
32663268

32673269
auto op = builder->create<cuf::KernelOp>(
3268-
loc, gridValues, blockValues, streamAddr, lbs, ubs, steps, n,
3270+
loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n,
32693271
mlir::ValueRange(reduceOperands), builder->getArrayAttr(reduceAttrs));
32703272
builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
32713273
ivLocs);

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ llvm::LogicalResult cuf::KernelOp::verify() {
271271
return emitOpError("expect reduce attributes to be ReduceAttr");
272272
}
273273
}
274-
return checkStreamType(*this);
274+
return mlir::success();
275275
}
276276

277277
//===----------------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ subroutine sub1()
7575
end do
7676
end
7777

78-
! CHECK: cuf.kernel<<<*, *, stream = %[[STREAM]]#0 : !fir.ref<i64>>>>
78+
! CHECK: %[[STREAM_LOAD:.*]] = fir.load %[[STREAM]]#0 : !fir.ref<i64>
79+
! CHECK: %[[STREAM_I32:.*]] = fir.convert %[[STREAM_LOAD]] : (i64) -> i32
80+
! CHECK: cuf.kernel<<<*, *, stream = %[[STREAM_I32]]>>>
7981

8082

8183
! Test lowering with unstructured construct inside.

0 commit comments

Comments
 (0)