Skip to content

Commit cd8229b

Browse files
authored
[flang][cuda] Support c_devptr in c_f_pointer intrinsic (#107470)
This is an extension of CUDA Fortran. The iso_c_binding intrinsic can accept a `TYPE(c_devptr)` as its first argument. This patch relax the semantic check to accept it and update the lowering to unwrap the cptr field from the c_devptr.
1 parent 7543d09 commit cd8229b

File tree

4 files changed

+60
-2
lines changed

4 files changed

+60
-2
lines changed

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ inline bool isa_builtin_cptr_type(mlir::Type t) {
139139
return false;
140140
}
141141

142+
/// Is `t` type(c_devptr)?
143+
inline bool isa_builtin_cdevptr_type(mlir::Type t) {
144+
if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(t))
145+
return recTy.getName().ends_with("T__builtin_c_devptr");
146+
return false;
147+
}
148+
142149
/// Is `t` a FIR dialect aggregate type?
143150
inline bool isa_aggregate(mlir::Type t) {
144151
return mlir::isa<SequenceType, mlir::TupleType>(t) || fir::isa_derived(t);

flang/lib/Evaluate/intrinsics.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2811,8 +2811,10 @@ IntrinsicProcTable::Implementation::HandleC_F_Pointer(
28112811
if (auto type{expr->GetType()}) {
28122812
if (type->category() != TypeCategory::Derived ||
28132813
type->IsPolymorphic() ||
2814-
type->GetDerivedTypeSpec().typeSymbol().name() !=
2815-
"__builtin_c_ptr") {
2814+
(type->GetDerivedTypeSpec().typeSymbol().name() !=
2815+
"__builtin_c_ptr" &&
2816+
type->GetDerivedTypeSpec().typeSymbol().name() !=
2817+
"__builtin_c_devptr")) {
28162818
context.messages().Say(arguments[0]->sourceLocation(),
28172819
"CPTR= argument to C_F_POINTER() must be a C_PTR"_err_en_US);
28182820
}

flang/lib/Optimizer/Builder/FIRBuilder.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,24 @@ mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
15801580
mlir::Location loc,
15811581
mlir::Value cPtr) {
15821582
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
1583+
if (fir::isa_builtin_cdevptr_type(cPtrTy)) {
1584+
// Unwrap c_ptr from c_devptr.
1585+
auto [addrFieldIndex, addrFieldTy] =
1586+
genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy);
1587+
mlir::Value cPtrCoor;
1588+
if (fir::isa_ref_type(cPtr.getType())) {
1589+
cPtrCoor = builder.create<fir::CoordinateOp>(
1590+
loc, builder.getRefType(addrFieldTy), cPtr, addrFieldIndex);
1591+
} else {
1592+
auto arrayAttr = builder.getArrayAttr(
1593+
{builder.getIntegerAttr(builder.getIndexType(), 0)});
1594+
cPtrCoor = builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr,
1595+
arrayAttr);
1596+
}
1597+
mlir::Value cptr = builder.create<fir::LoadOp>(loc, cPtrCoor);
1598+
return genCPtrOrCFunptrValue(builder, loc, cptr);
1599+
}
1600+
15831601
if (fir::isa_ref_type(cPtr.getType())) {
15841602
mlir::Value cPtrAddr =
15851603
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);

flang/test/Lower/CUDA/cuda-devptr.cuf

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
! Test CUDA Fortran specific type
44

5+
module cudafct
6+
use __fortran_builtins, only : c_devptr => __builtin_c_devptr
7+
contains
8+
function c_devloc(x)
9+
use iso_c_binding, only: c_loc
10+
type(c_devptr) :: c_devloc
11+
!dir$ ignore_tkr (tkr) x
12+
real, target, device :: x
13+
c_devloc%cptr = c_loc(x)
14+
end function
15+
end
16+
517
subroutine sub1()
618
use iso_c_binding
719
use __fortran_builtins, only : c_devptr => __builtin_c_devptr
@@ -14,3 +26,22 @@ end
1426

1527
! CHECK-LABEL: func.func @_QPsub1()
1628
! CHECK-COUNT-2: %{{.*}} = fir.call @_FortranAioOutputDerivedType
29+
30+
subroutine sub2()
31+
use cudafct
32+
use iso_c_binding, only: c_f_pointer
33+
34+
real(4), device :: a(8, 10)
35+
real(4), device, pointer :: x(:)
36+
call c_f_pointer(c_devloc(a), x, (/80/))
37+
end
38+
39+
! CHECK-LABEL: func.func @_QPsub2()
40+
! CHECK: %[[X:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFsub2Ex"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
41+
! CHECK: %[[CPTR:.*]] = fir.field_index cptr, !fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>
42+
! CHECK: %[[CPTR_COORD:.*]] = fir.coordinate_of %{{.*}}#1, %[[CPTR]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>>, !fir.field) -> !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
43+
! CHECK: %[[CPTR_LOAD:.*]] = fir.load %[[CPTR_COORD]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
44+
! CHECK: %[[ADDRESS:.*]] = fir.extract_value %[[CPTR_LOAD]], [0 : index] : (!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> i64
45+
! CHECK: %[[ADDRESS_IDX:.*]] = fir.convert %[[ADDRESS]] : (i64) -> !fir.ptr<!fir.array<?xf32>>
46+
! CHECK: %[[EMBOX:.*]] = fir.embox %[[ADDRESS_IDX]](%{{.*}}) : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
47+
! CHECK: fir.store %[[EMBOX]] to %[[X]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>

0 commit comments

Comments
 (0)