Skip to content

Commit 8ee7d97

Browse files
authored
[flang][cuda] Add fir.cuda_allocate operation (#88586)
Allocatable with cuda device attribute have special semantic for the allocate statement. In flang the allocate statement is lowered to a sequence of runtime call initializing the descriptor and then allocating the descriptor data. This new operation will replace the last runtime call and abstract all the device memory allocation needed. The lowering patch will follow.
1 parent 00ae4b7 commit 8ee7d97

File tree

5 files changed

+172
-0
lines changed

5 files changed

+172
-0
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3190,4 +3190,36 @@ def fir_CUDADataTransferOp : fir_Op<"cuda_data_transfer", []> {
31903190
}];
31913191
}
31923192

3193+
def fir_CUDAAllocateOp : fir_Op<"cuda_allocate", [AttrSizedOperandSegments,
3194+
MemoryEffects<[MemAlloc<DefaultResource>]>]> {
3195+
let summary = "Perform the device allocation of data of an allocatable";
3196+
3197+
let description = [{
3198+
The fir.cuda_allocate operation performs the allocation on the device
3199+
of the data of an allocatable. The descriptor passed to the operation
3200+
is initialized before with the standard flang runtime calls.
3201+
}];
3202+
3203+
let arguments = (ins Arg<AnyRefOrBoxType, "", [MemWrite]>:$box,
3204+
Arg<Optional<AnyRefOrBoxType>, "", [MemWrite]>:$errmsg,
3205+
Optional<AnyIntegerType>:$stream,
3206+
Arg<Optional<AnyRefOrBoxType>, "", [MemWrite]>:$pinned,
3207+
Arg<Optional<AnyRefOrBoxType>, "", [MemRead]>:$source,
3208+
fir_CUDADataAttributeAttr:$cuda_attr,
3209+
UnitAttr:$hasStat);
3210+
3211+
let results = (outs AnyIntegerType:$stat);
3212+
3213+
let assemblyFormat = [{
3214+
$box `:` qualified(type($box))
3215+
( `source` `(` $source^ `:` qualified(type($source) )`)` )?
3216+
( `errmsg` `(` $errmsg^ `:` type($errmsg) `)` )?
3217+
( `stream` `(` $stream^ `:` type($stream) `)` )?
3218+
( `pinned` `(` $pinned^ `:` type($pinned) `)` )?
3219+
attr-dict `->` type($stat)
3220+
}];
3221+
3222+
let hasVerifier = 1;
3223+
}
3224+
31933225
#endif

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ def AnyRefOrBoxLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
625625
def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
626626
fir_HeapType.predicate, fir_PointerType.predicate,
627627
IsBaseBoxTypePred]>, "any reference or box">;
628+
def AnyRefOrBoxType : Type<AnyRefOrBox.predicate, "any legal ref or box type">;
628629

629630
def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
630631
fir_ShapeShiftType.predicate]>, "any legal shape type">;

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3993,6 +3993,25 @@ mlir::LogicalResult fir::CUDAKernelOp::verify() {
39933993
return mlir::success();
39943994
}
39953995

3996+
mlir::LogicalResult fir::CUDAAllocateOp::verify() {
3997+
if (getPinned() && getStream())
3998+
return emitOpError("pinned and stream cannot appears at the same time");
3999+
if (!fir::unwrapRefType(getBox().getType()).isa<fir::BaseBoxType>())
4000+
return emitOpError(
4001+
"expect box to be a reference to/or a class or box type value");
4002+
if (getSource() &&
4003+
!fir::unwrapRefType(getSource().getType()).isa<fir::BaseBoxType>())
4004+
return emitOpError(
4005+
"expect source to be a reference to/or a class or box type value");
4006+
if (getErrmsg() &&
4007+
!fir::unwrapRefType(getErrmsg().getType()).isa<fir::BoxType>())
4008+
return emitOpError(
4009+
"expect errmsg to be a reference to/or a box type value");
4010+
if (getErrmsg() && !getHasStat())
4011+
return emitOpError("expect stat attribute when errmsg is provided");
4012+
return mlir::success();
4013+
}
4014+
39964015
//===----------------------------------------------------------------------===//
39974016
// FIROpsDialect
39984017
//===----------------------------------------------------------------------===//

flang/test/Fir/cuf-invalid.fir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: fir-opt -split-input-file -verify-diagnostics %s
2+
3+
func.func @_QPsub1() {
4+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
5+
%1 = fir.alloca i32
6+
%pinned = fir.alloca i1
7+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
8+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
9+
%s = fir.load %1 : !fir.ref<i32>
10+
// expected-error@+1{{'fir.cuda_allocate' op pinned and stream cannot appears at the same time}}
11+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) pinned(%pinned : !fir.ref<i1>) {cuda_attr = #fir.cuda<device>} -> i32
12+
return
13+
}
14+
15+
// -----
16+
17+
func.func @_QPsub1() {
18+
%1 = fir.alloca i32
19+
// expected-error@+1{{'fir.cuda_allocate' op expect box to be a reference to/or a class or box type value}}
20+
%2 = fir.cuda_allocate %1 : !fir.ref<i32> {cuda_attr = #fir.cuda<device>} -> i32
21+
return
22+
}
23+
24+
// -----
25+
26+
func.func @_QPsub1() {
27+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
28+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
29+
%c100 = arith.constant 100 : index
30+
%7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
31+
%8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
32+
%9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
33+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
34+
%16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
35+
// expected-error@+1{{'fir.cuda_allocate' op expect stat attribute when errmsg is provided}}
36+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {cuda_attr = #fir.cuda<device>} -> i32
37+
return
38+
}
39+
40+
// -----
41+
42+
func.func @_QPsub1() {
43+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
44+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
45+
%1 = fir.alloca i32
46+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
47+
// expected-error@+1{{'fir.cuda_allocate' op expect errmsg to be a reference to/or a box type value}}
48+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%1 : !fir.ref<i32>) {cuda_attr = #fir.cuda<device>, hasStat} -> i32
49+
return
50+
}

flang/test/Fir/cuf.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: fir-opt --split-input-file %s | fir-opt --split-input-file | FileCheck %s
2+
3+
// Simple round trip test of operations.
4+
5+
func.func @_QPsub1() {
6+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
7+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
8+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
9+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> {cuda_attr = #fir.cuda<device>} -> i32
10+
return
11+
}
12+
13+
// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> {cuda_attr = #fir.cuda<device>} -> i32
14+
15+
// -----
16+
17+
func.func @_QPsub1() {
18+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
19+
%1 = fir.alloca i32
20+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
21+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
22+
%s = fir.load %1 : !fir.ref<i32>
23+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> stream(%s : i32) {cuda_attr = #fir.cuda<device>} -> i32
24+
return
25+
}
26+
27+
// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> stream(%{{.*}} : i32) {cuda_attr = #fir.cuda<device>} -> i32
28+
29+
// -----
30+
31+
func.func @_QPsub1() {
32+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
33+
%1 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "b", uniq_name = "_QFsub1Eb"}
34+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
35+
%5:2 = hlfir.declare %1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
36+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
37+
%12 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
38+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> source(%12 : !fir.ref<!fir.box<none>>) {cuda_attr = #fir.cuda<device>} -> i32
39+
return
40+
}
41+
42+
// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> source(%{{.*}} : !fir.ref<!fir.box<none>>) {cuda_attr = #fir.cuda<device>} -> i32
43+
44+
// -----
45+
46+
func.func @_QPsub1() {
47+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
48+
%pinned = fir.alloca i1
49+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
50+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
51+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> pinned(%pinned : !fir.ref<i1>) {cuda_attr = #fir.cuda<device>} -> i32
52+
return
53+
}
54+
55+
// CHECK: fir.cuda_allocate %{{.*}} : !fir.ref<!fir.box<none>> pinned(%{{.*}} : !fir.ref<i1>) {cuda_attr = #fir.cuda<device>} -> i32
56+
57+
// -----
58+
59+
func.func @_QPsub1() {
60+
%0 = fir.alloca !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", uniq_name = "_QFsub1Ea"}
61+
%4:2 = hlfir.declare %0 {cuda_attr = #fir.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
62+
%c100 = arith.constant 100 : index
63+
%7 = fir.alloca !fir.char<1,100> {bindc_name = "msg", uniq_name = "_QFsub1Emsg"}
64+
%8:2 = hlfir.declare %7 typeparams %c100 {uniq_name = "_QFsub1Emsg"} : (!fir.ref<!fir.char<1,100>>, index) -> (!fir.ref<!fir.char<1,100>>, !fir.ref<!fir.char<1,100>>)
65+
%9 = fir.embox %8#1 : (!fir.ref<!fir.char<1,100>>) -> !fir.box<!fir.char<1,100>>
66+
%11 = fir.convert %4#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
67+
%16 = fir.convert %9 : (!fir.box<!fir.char<1,100>>) -> !fir.box<none>
68+
%13 = fir.cuda_allocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {cuda_attr = #fir.cuda<device>, hasStat} -> i32
69+
return
70+
}

0 commit comments

Comments
 (0)