Skip to content

Commit 33cb29c

Browse files
authored
[flang][cuda] Use cuf.alloc/cuf.free for local descriptor (#98518)
Local descriptors for CUDA allocatables need to be handled on both the host and the device. One solution is to duplicate the descriptor (one on the host and one on the device) and keep the two copies in sync; another is to place the descriptor in managed/unified memory so we don't have to take care of any sync. The second solution is probably the one we will implement. In order to have more flexibility in how descriptors representing CUDA allocatables are allocated, this patch updates the lowering to use the cuf operations alloc and free to manage them.
1 parent a742693 commit 33cb29c

File tree

5 files changed

+38
-40
lines changed

5 files changed

+38
-40
lines changed

flang/include/flang/Semantics/tools.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,19 +222,15 @@ inline bool HasCUDAAttr(const Symbol &sym) {
222222
}
223223

224224
inline bool NeedCUDAAlloc(const Symbol &sym) {
225-
bool inDeviceSubprogram{IsCUDADeviceContext(&sym.owner())};
226225
if (IsDummy(sym)) {
227226
return false;
228227
}
229228
if (const auto *details{sym.GetUltimate().detailsIf<ObjectEntityDetails>()}) {
230229
if (details->cudaDataAttr() &&
231230
(*details->cudaDataAttr() == common::CUDADataAttr::Device ||
232231
*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
233-
*details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
234-
// Descriptor is allocated on host when in host context.
235-
if (IsAllocatable(sym)) {
236-
return inDeviceSubprogram;
237-
}
232+
*details->cudaDataAttr() == common::CUDADataAttr::Unified ||
233+
*details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
238234
return true;
239235
}
240236
}

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -715,8 +715,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
715715
auto idxTy = builder.getIndexType();
716716
for (mlir::Value sh : elidedShape)
717717
indices.push_back(builder.createConvert(loc, idxTy, sh));
718-
return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr, lenParams,
719-
indices);
718+
mlir::Value alloc = builder.create<cuf::AllocOp>(
719+
loc, ty, nm, symNm, dataAttr, lenParams, indices);
720+
return alloc;
720721
}
721722

722723
// Let the builder do all the heavy lifting.
@@ -927,6 +928,19 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
927928
finalizeAtRuntime(converter, var, symMap);
928929
if (mustBeDefaultInitializedAtRuntime(var))
929930
defaultInitializeAtRuntime(converter, var, symMap);
931+
if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
932+
auto *builder = &converter.getFirOpBuilder();
933+
mlir::Location loc = converter.getCurrentLocation();
934+
fir::ExtendedValue exv =
935+
converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
936+
auto *sym = &var.getSymbol();
937+
converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
938+
cuf::DataAttributeAttr dataAttr =
939+
Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),
940+
*sym);
941+
builder->create<cuf::FreeOp>(loc, fir::getBase(exv), dataAttr);
942+
});
943+
}
930944
if (std::optional<VariableCleanUp> cleanup =
931945
needDeallocationOrFinalization(var)) {
932946
auto *builder = &converter.getFirOpBuilder();
@@ -950,22 +964,10 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
950964
"trying to deallocate entity not lowered as allocatable");
951965
Fortran::lower::genDeallocateIfAllocated(*converterPtr, *mutableBox,
952966
loc, sym);
967+
953968
});
954969
}
955970
}
956-
if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
957-
auto *builder = &converter.getFirOpBuilder();
958-
mlir::Location loc = converter.getCurrentLocation();
959-
fir::ExtendedValue exv =
960-
converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
961-
auto *sym = &var.getSymbol();
962-
converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
963-
cuf::DataAttributeAttr dataAttr =
964-
Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),
965-
*sym);
966-
builder->create<cuf::FreeOp>(loc, fir::getBase(exv), dataAttr);
967-
});
968-
}
969971
}
970972

971973
//===----------------------------------------------------------------===//

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ template <typename Op>
5454
static llvm::LogicalResult checkCudaAttr(Op op) {
5555
if (op.getDataAttr() == cuf::DataAttribute::Device ||
5656
op.getDataAttr() == cuf::DataAttribute::Managed ||
57-
op.getDataAttr() == cuf::DataAttribute::Unified)
57+
op.getDataAttr() == cuf::DataAttribute::Unified ||
58+
op.getDataAttr() == cuf::DataAttribute::Pinned)
5859
return mlir::success();
59-
return op.emitOpError("expect device, managed or unified cuda attribute");
60+
return op.emitOpError()
61+
<< "expect device, managed, pinned or unified cuda attribute";
6062
}
6163

6264
llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); }

flang/test/Fir/cuf-invalid.fir

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,9 @@ func.func @_QPsub1() {
8888

8989
// -----
9090

91-
func.func @_QPsub1() {
92-
// expected-error@+1{{'cuf.alloc' op expect device, managed or unified cuda attribute}}
93-
%0 = cuf.alloc f32 {bindc_name = "r", data_attr = #cuf.cuda<pinned>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
94-
cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
95-
return
96-
}
97-
98-
// -----
99-
10091
func.func @_QPsub1() {
10192
%0 = cuf.alloc f32 {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
102-
// expected-error@+1{{'cuf.free' op expect device, managed or unified cuda attribute}}
93+
// expected-error@+1{{'cuf.free' op expect device, managed, pinned or unified cuda attribute}}
10394
cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
10495
return
10596
}

flang/test/Lower/CUDA/cuda-allocatable.cuf

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ subroutine sub1()
1010
end subroutine
1111

1212
! CHECK-LABEL: func.func @_QPsub1()
13-
! CHECK: %[[BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
13+
! CHECK: %[[BOX:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
1414
! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
1515
! CHECK: fir.call @_FortranAAllocatableSetBounds
1616
! CHECK: %{{.*}} = cuf.allocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
@@ -25,6 +25,7 @@ end subroutine
2525
! CHECK: fir.if %[[NE_C0]] {
2626
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
2727
! CHECK: }
28+
! CHECK: cuf.free %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
2829

2930
subroutine sub2()
3031
real, allocatable, managed :: a(:)
@@ -35,7 +36,7 @@ subroutine sub2()
3536
end subroutine
3637

3738
! CHECK-LABEL: func.func @_QPsub2()
38-
! CHECK: %[[BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub2Ea"}
39+
! CHECK: %[[BOX:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QFsub2Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
3940
! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<managed>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
4041
! CHECK: %[[ISTAT:.*]] = fir.alloca i32 {bindc_name = "istat", uniq_name = "_QFsub2Eistat"}
4142
! CHECK: %[[ISTAT_DECL:.*]]:2 = hlfir.declare %[[ISTAT]] {uniq_name = "_QFsub2Eistat"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -49,6 +50,7 @@ end subroutine
4950
! CHECK: fir.if %{{.*}} {
5051
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<managed>} -> i32
5152
! CHECK: }
53+
! CHECK: cuf.free %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<managed>}
5254

5355
subroutine sub3()
5456
integer, allocatable, pinned :: a(:,:)
@@ -57,7 +59,7 @@ subroutine sub3()
5759
end subroutine
5860

5961
! CHECK-LABEL: func.func @_QPsub3()
60-
! CHECK: %[[BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xi32>>> {bindc_name = "a", uniq_name = "_QFsub3Ea"}
62+
! CHECK: %[[BOX:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?x?xi32>>> {bindc_name = "a", data_attr = #cuf.cuda<pinned>, uniq_name = "_QFsub3Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
6163
! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
6264
! CHECK: %[[PLOG:.*]] = fir.alloca !fir.logical<4> {bindc_name = "plog", uniq_name = "_QFsub3Eplog"}
6365
! CHECK: %[[PLOG_DECL:.*]]:2 = hlfir.declare %5 {uniq_name = "_QFsub3Eplog"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
@@ -66,6 +68,7 @@ end subroutine
6668
! CHECK: fir.if %{{.*}} {
6769
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {data_attr = #cuf.cuda<pinned>} -> i32
6870
! CHECK: }
71+
! CHECK: cuf.free %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>> {data_attr = #cuf.cuda<pinned>}
6972

7073
subroutine sub4()
7174
real, allocatable, device :: a(:)
@@ -74,7 +77,7 @@ subroutine sub4()
7477
end subroutine
7578

7679
! CHECK-LABEL: func.func @_QPsub4()
77-
! CHECK: %[[BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub4Ea"}
80+
! CHECK: %[[BOX:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
7881
! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
7982
! CHECK: %[[ISTREAM:.*]] = fir.alloca i32 {bindc_name = "istream", uniq_name = "_QFsub4Eistream"}
8083
! CHECK: %[[ISTREAM_DECL:.*]]:2 = hlfir.declare %[[ISTREAM]] {uniq_name = "_QFsub4Eistream"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -84,6 +87,7 @@ end subroutine
8487
! CHECK: fir.if %{{.*}} {
8588
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
8689
! CHECK: }
90+
! CHECK: cuf.free %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
8791

8892
subroutine sub5()
8993
real, allocatable, device :: a(:)
@@ -92,7 +96,7 @@ subroutine sub5()
9296
end subroutine
9397

9498
! CHECK-LABEL: func.func @_QPsub5()
95-
! CHECK: %[[BOX_A:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub5Ea"}
99+
! CHECK: %[[BOX_A:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub5Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
96100
! CHECK: %[[BOX_A_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
97101
! CHECK: %[[BOX_B:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub5Eb"}
98102
! CHECK: %[[BOX_B_DECL:.*]]:2 = hlfir.declare %[[BOX_B]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
@@ -104,6 +108,7 @@ end subroutine
104108
! CHECK: fir.if %{{.*}} {
105109
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_A_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
106110
! CHECK: }
111+
! CHECK: cuf.free %[[BOX_A_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
107112

108113
subroutine sub6()
109114
real, allocatable, device :: a(:)
@@ -112,7 +117,7 @@ subroutine sub6()
112117
end subroutine
113118

114119
! CHECK-LABEL: func.func @_QPsub6()
115-
! CHECK: %[[BOX_A:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub6Ea"}
120+
! CHECK: %[[BOX_A:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
116121
! CHECK: %[[BOX_A_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub6Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
117122
! CHECK: %[[BOX_B:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub6Eb"}
118123
! CHECK: %[[BOX_B_DECL:.*]]:2 = hlfir.declare %[[BOX_B]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub6Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
@@ -122,6 +127,7 @@ end subroutine
122127
! CHECK: fir.if %{{.*}} {
123128
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_A_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
124129
! CHECK: }
130+
! CHECK: cuf.free %[[BOX_A_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
125131

126132
subroutine sub7()
127133
real, allocatable, device :: a(:)
@@ -133,7 +139,7 @@ subroutine sub7()
133139
end subroutine
134140

135141
! CHECK-LABEL: func.func @_QPsub7()
136-
! CHECK: %[[BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub7Ea"}
142+
! CHECK: %[[BOX:.*]] = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub7Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
137143
! CHECK: %[[BOX_DECL:.*]]:2 = hlfir.declare %[[BOX]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
138144
! CHECK: %[[ERR:.*]] = fir.alloca !fir.char<1,50> {bindc_name = "err", uniq_name = "_QFsub7Eerr"}
139145
! CHECK: %[[ERR_DECL:.*]]:2 = hlfir.declare %[[ERR]] typeparams %{{.*}} {uniq_name = "_QFsub7Eerr"} : (!fir.ref<!fir.char<1,50>>, index) -> (!fir.ref<!fir.char<1,50>>, !fir.ref<!fir.char<1,50>>)
@@ -150,3 +156,4 @@ end subroutine
150156
! CHECK: fir.if %{{.*}} {
151157
! CHECK: %{{.*}} = cuf.deallocate %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
152158
! CHECK: }
159+
! CHECK: cuf.free %[[BOX_DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}

0 commit comments

Comments
 (0)