Skip to content

Commit d1abbb4

Browse files
authored
[flang][cuda] Change induction variable from i32 to index for doconcurrent inside cuf kernel directive (#129924)
Use `index` instead of `i32` for induction variables for doconcurrent inside cuf kernel directive. Regular do loop inside cuf kernel directive also uses `index`: ``` cuf.kernel<<<*, *>>> (%arg0 : index) = ... ```
1 parent 77363f7 commit d1abbb4

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3150,10 +3150,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31503150
loc, 1); // Use index type directly
31513151

31523152
// Ensure lb, ub, and step are of index type using fir.convert
3153-
mlir::Type indexType = builder->getIndexType();
3154-
lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
3155-
ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
3156-
step = builder->create<fir::ConvertOp>(loc, indexType, step);
3153+
lb = builder->create<fir::ConvertOp>(loc, idxTy, lb);
3154+
ub = builder->create<fir::ConvertOp>(loc, idxTy, ub);
3155+
step = builder->create<fir::ConvertOp>(loc, idxTy, step);
31573156

31583157
lbs.push_back(lb);
31593158
ubs.push_back(ub);
@@ -3163,18 +3162,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31633162

31643163
// Handle induction variable
31653164
mlir::Value ivValue = getSymbolAddress(*name.symbol);
3166-
std::size_t ivTypeSize = name.symbol->size();
3167-
if (ivTypeSize == 0)
3168-
llvm::report_fatal_error("unexpected induction variable size");
3169-
mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);
31703165

31713166
if (!ivValue) {
31723167
// DO CONCURRENT induction variables are not mapped yet since they are
31733168
// local to the DO CONCURRENT scope.
31743169
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
31753170
builder->setInsertionPointToStart(builder->getAllocaBlock());
31763171
ivValue = builder->createTemporaryAlloc(
3177-
loc, ivTy, toStringRef(name.symbol->name()));
3172+
loc, idxTy, toStringRef(name.symbol->name()));
31783173
builder->restoreInsertionPoint(insPt);
31793174
}
31803175

@@ -3186,7 +3181,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31863181
// Bind the symbol to the declared variable
31873182
bindSymbol(*name.symbol, ivValue);
31883183
ivValues.push_back(ivValue);
3189-
ivTypes.push_back(ivTy);
3184+
ivTypes.push_back(idxTy);
31903185
ivLocs.push_back(loc);
31913186
}
31923187
} else {

flang/test/Lower/CUDA/cuda-doconc.cuf

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ subroutine doconc1
1515
end
1616

1717
! CHECK: func.func @_QPdoconc1() {
18-
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
18+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
1919
! CHECK: cuf.kernel<<<*, *>>>
20-
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>
20+
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<index>
2121

2222
subroutine doconc2
2323
integer :: i, j, m, n
@@ -32,8 +32,8 @@ subroutine doconc2
3232
end
3333

3434
! CHECK: func.func @_QPdoconc2() {
35-
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
36-
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
37-
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
38-
! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<i32>
39-
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<i32>
35+
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
36+
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
37+
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : index, %arg1 : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
38+
! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<index>
39+
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<index>

0 commit comments

Comments
 (0)