9
9
#include " flang/Runtime/CUDA/memory.h"
10
10
#include " gtest/gtest.h"
11
11
#include " ../../../runtime/terminator.h"
12
+ #include " ../tools.h"
12
13
#include " flang/Common/Fortran.h"
14
+ #include " flang/Runtime/CUDA/allocator.h"
13
15
#include " flang/Runtime/CUDA/common.h"
16
+ #include " flang/Runtime/CUDA/descriptor.h"
17
+ #include " flang/Runtime/allocatable.h"
18
+ #include " flang/Runtime/allocator-registry.h"
14
19
15
20
#include " cuda_runtime.h"
16
21
22
+ using namespace Fortran ::runtime;
17
23
using namespace Fortran ::runtime::cuda;
18
24
19
25
TEST (MemoryCUFTest, SimpleAllocTramsferFree) {
@@ -29,3 +35,37 @@ TEST(MemoryCUFTest, SimpleAllocTramsferFree) {
29
35
EXPECT_EQ (42 , host);
30
36
RTNAME (CUFMemFree)((void *)dev, kMemTypeDevice , __FILE__, __LINE__);
31
37
}
38
+
39
+ static OwningPtr<Descriptor> createAllocatable (
40
+ Fortran::common::TypeCategory tc, int kind, int rank = 1 ) {
41
+ return Descriptor::Create (TypeCode{tc, kind}, kind, nullptr , rank, nullptr ,
42
+ CFI_attribute_allocatable);
43
+ }
44
+
45
+ TEST (MemoryCUFTest, CUFDataTransferDescDesc) {
46
+ using Fortran::common::TypeCategory;
47
+ RTNAME (CUFRegisterAllocator)();
48
+ // INTEGER(4), DEVICE, ALLOCATABLE :: a(:)
49
+ auto dev{createAllocatable (TypeCategory::Integer, 4 )};
50
+ dev->SetAllocIdx (kDeviceAllocatorPos );
51
+ EXPECT_EQ ((int )kDeviceAllocatorPos , dev->GetAllocIdx ());
52
+ RTNAME (AllocatableSetBounds)(*dev, 0 , 1 , 10 );
53
+ RTNAME (AllocatableAllocate)
54
+ (*dev, /* hasStat=*/ false , /* errMsg=*/ nullptr , __FILE__, __LINE__);
55
+ EXPECT_TRUE (dev->IsAllocated ());
56
+
57
+ // Create temp array to transfer to device.
58
+ auto x{MakeArray<TypeCategory::Integer, 4 >(std::vector<int >{10 },
59
+ std::vector<int32_t >{0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 })};
60
+ RTNAME (CUFDataTransferDescDesc)(dev.get (), x.get (), kHostToDevice , __FILE__, __LINE__);
61
+
62
+ // Retrieve data from device.
63
+ auto host{MakeArray<TypeCategory::Integer, 4 >(std::vector<int >{10 },
64
+ std::vector<int32_t >{0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 })};
65
+ RTNAME (CUFDataTransferDescDesc)
66
+ (host.get (), dev.get (), kDeviceToHost , __FILE__, __LINE__);
67
+
68
+ for (unsigned i = 0 ; i < 10 ; ++i) {
69
+ EXPECT_EQ (*host->ZeroBasedIndexedElement <std::int32_t >(i), (std::int32_t )i);
70
+ }
71
+ }
0 commit comments