@@ -4,6 +4,12 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#if defined(USE_ROCM)
+#include <hip/amd_detail/amd_hip_bf16.h>
+#include <hip/amd_detail/amd_hip_atomic.h>
+#include <hip/amd_detail/hip_ldg.h>
+#endif
+
 namespace c10d {
 namespace intra_node_comm {
 
@@ -17,7 +23,7 @@ static constexpr size_t kOneShotThreshBytes = 256 * 1024;
 static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
 
 #if defined(USE_ROCM)
-using __nv_bfloat162 = uint32_t;
+using __nv_bfloat162 = __hip_bfloat162;
 #endif
 
 struct __align__(16) bf16x8 {
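Note: with the alias retargeted from an opaque uint32_t to a real HIP bf16 vector type, the shared bf16 math below can compile on ROCm instead of trapping. A minimal standalone sketch (not part of the patch) of what this enables, assuming ROCm's amd_hip_bf16.h provides an __hadd2 overload for __hip_bfloat162, mirroring CUDA's cuda_bf16.h API:

    // Standalone sketch, not from the patch: elementwise add of two packed
    // bf16 lanes. Assumes ROCm's amd_hip_bf16.h, which defines
    // __hip_bfloat162 and an __hadd2 overload for it.
    #include <hip/amd_detail/amd_hip_bf16.h>

    __global__ void bf16x2_add_demo(const __hip_bfloat162* a,
                                    const __hip_bfloat162* b,
                                    __hip_bfloat162* out) {
      // With the alias above, the same __hadd2 expression builds on both
      // CUDA (sm_80+) and ROCm.
      out[0] = __hadd2(a[0], b[0]);
    }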
@@ -28,10 +34,7 @@ struct __align__(16) bf16x8 {
 
 DEVICE_INLINE __nv_bfloat162
 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
-#if defined(USE_ROCM)
-  CUDA_KERNEL_ASSERT(false);
-  return 0;
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
   __nv_bfloat162 res;
   return res;
@@ -70,8 +73,12 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
 */
 template <typename T>
 DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
+#elif defined(USE_ROCM)
+  ulonglong2 l_val = __ldg(reinterpret_cast<const ulonglong2*>(addr));
+  reinterpret_cast<unsigned long long*>(&val)[0] = l_val.data[0];
+  reinterpret_cast<unsigned long long*>(&val)[1] = l_val.data[1];
 #else
   unsigned long long int low, high;
   asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
@@ -83,8 +90,13 @@ DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
 }
 
 __device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
+#elif defined(USE_ROCM)
+  for (int i = 0; i < 8; i++)
+  {
+    addr[i] = reinterpret_cast<const at::BFloat16*>(&val)[i];
+  }
 #else
   unsigned long long int low, high;
   low = reinterpret_cast<const unsigned long long int*>(&val)[0];
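Note: the ROCm store path falls back to eight scalar bf16 stores rather than a single non-temporal 128-bit store, since there is no st.global.nc equivalent. Either way, each streamLoad128/streamStore128 pair moves one bf16x8 (eight bf16 values, 16 bytes). A minimal usage sketch of the two helpers together; the kernel and the names src, dst, numel are illustrative, not from the patch:

    // Illustrative copy kernel (not part of the patch). Assumes numel is a
    // multiple of 8 and both pointers are 16-byte aligned, matching the
    // bf16x8 granularity the helpers operate on.
    __global__ void copy_bf16_kernel(at::BFloat16* dst,
                                     const at::BFloat16* src,
                                     size_t numel) {
      const size_t offset =
          (blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x) * 8;
      if (offset < numel) {
        bf16x8 val;
        streamLoad128(val, src + offset);   // one 128-bit load (or fallback)
        streamStore128(dst + offset, val);  // one 128-bit store (or fallback)
      }
    }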
@@ -104,15 +116,16 @@ DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
 }
 
 DEVICE_INLINE void releaseSignal(uint32_t* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
 #else
   atomicAdd_system(addr, 1);
+  __threadfence_system();
 #endif
 }
 
 DEVICE_INLINE void acquireSignal(uint32_t* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
 #else
   volatile uint32_t* signal = addr;
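Note: releaseSignal/acquireSignal form a simple cross-GPU flag: the producer bumps a system-scope counter (now followed by a system-wide fence on this path) and the consumer spins on the flag until the expected signal arrives. A rough sketch of the intended pairing; both kernels and the buffer names are illustrative, not from the patch:

    // Rough sketch of the release/acquire pairing (names are illustrative).
    __global__ void producer(at::BFloat16* buf, uint32_t* signal) {
      // ... write buf ...
      if (threadIdx.x == 0) {
        releaseSignal(signal);  // system-scope atomic increment + fence
      }
    }

    __global__ void consumer(const at::BFloat16* buf, uint32_t* signal) {
      if (threadIdx.x == 0) {
        acquireSignal(signal);  // spin until the producer's signal arrives
      }
      __syncthreads();
      // ... read buf ...
    }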
@@ -484,7 +497,7 @@ static void getLaunchConfig(
 }
 
 bool isIntraNodeCommSupported() {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   return false;
 #else
   return true;