
Commit 031b577

jataylo authored and pragupta committed
[ROCm] Intra-node all reduce initial implementation (#1435)
* Initial commit to port intra_node_comm to ROCm (cherry picked from commit 48d1c33)
* gpt-fast running now with intra-node comm (cherry picked from commit 618c54e)

Co-authored-by: Prachi Gupta <[email protected]>
1 parent f8d9e0e commit 031b577

File tree: 3 files changed (+66, -15 lines)


caffe2/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -601,6 +601,10 @@ if(USE_ROCM)
   append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS)
   if(NOT WIN32)
     append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS)
+    set_source_files_properties(
+      ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
+      PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
+    )
   endif()
 endif()
 # caffe2_nvrtc's stubs to driver APIs are useful for HIP.
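
A note on intent: this per-file define exposes the driver-API code paths in intra_node_comm.cpp to the ROCm build even though nvml.h is absent there. A minimal sketch (mine, not part of the commit) of the three-way selection the two macros now express inside that file:

  // Sketch: backend selection implied by the CMake change above.
  #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
    // CUDA build: NVML is available for NVLink topology queries.
  #elif defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
    // ROCm build: ROCm SMI (rsmi_*) stands in for NVML.
  #else
    // Neither: the driver-API-dependent paths are compiled out.
  #endif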

torch/csrc/distributed/c10d/intra_node_comm.cpp

Lines changed: 39 additions & 5 deletions
@@ -18,6 +18,8 @@
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 #include <c10/cuda/driver_api.h>
 #include <nvml.h>
+#else
+#include <rocm_smi/rocm_smi.h>
 #endif
 
 #include <cuda_runtime.h>
@@ -148,7 +150,26 @@ static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
   }
   return nvlMesh;
 #else
-  return {};
+  NvlMesh nvlMesh = {};
+  const auto worldSize = rankToBusId.size();
+  // For each device, loop over the devices connected to it
+  for (size_t idx = 0; idx < worldSize; ++idx) {
+    for (size_t link = 0; link < kMaxDevices; ++link) {
+      if (idx == link) continue;
+
+      bool conn = false;
+      auto ret = rsmi_is_P2P_accessible(idx, link, &conn);
+      if (ret != RSMI_STATUS_SUCCESS) {
+        LOG(ERROR) << "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret=" << ret;
+        return {};
+      }
+
+      if (conn) {
+        nvlMesh[idx][link] += 1;
+      }
+    }
+  }
+  return nvlMesh;
 #endif
 }
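
For reference, the rsmi_* calls above come from the ROCm SMI library and can be exercised outside PyTorch. A standalone host-side sketch (my own, assuming ROCm SMI is installed; link with -lrocm_smi64) that prints the same P2P adjacency getNvlMesh now derives:

  #include <rocm_smi/rocm_smi.h>
  #include <cstdint>
  #include <cstdio>

  int main() {
    // Initialize ROCm SMI before any other rsmi_* call.
    if (rsmi_init(0) != RSMI_STATUS_SUCCESS) {
      return 1;
    }
    uint32_t numDevices = 0;
    rsmi_num_monitor_devices(&numDevices);
    for (uint32_t src = 0; src < numDevices; ++src) {
      for (uint32_t dst = 0; dst < numDevices; ++dst) {
        if (src == dst) continue;
        bool accessible = false;
        if (rsmi_is_P2P_accessible(src, dst, &accessible) == RSMI_STATUS_SUCCESS &&
            accessible) {
          printf("GPU %u -> GPU %u: P2P accessible\n", src, dst);
        }
      }
    }
    rsmi_shut_down();
    return 0;
  }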

@@ -274,7 +295,6 @@ bool IntraNodeComm::rendezvous() {
   if (isInitialized_) {
     return true;
   }
-#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
   if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
       worldSize_ > kMaxDevices) {
     return false;
@@ -291,12 +311,28 @@ bool IntraNodeComm::rendezvous() {
 
   DevInfo devInfo{};
   gethostname(devInfo.hostname, sizeof(devInfo.hostname));
+
+#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  auto ret = rsmi_init(0);
+  if (ret != RSMI_STATUS_SUCCESS) {
+    LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret;
+    return false;
+  }
+#endif
+
   cudaDeviceProp prop{};
   AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx_));
+
+#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  auto pci_format = "%08X:%02X:%02X.0";
+#else
+  auto pci_format = NVML_DEVICE_PCI_BUS_ID_FMT;
+#endif
+
   snprintf(
       devInfo.busId,
       sizeof(devInfo.busId),
-      NVML_DEVICE_PCI_BUS_ID_FMT,
+      pci_format,
       prop.pciDomainID,
       prop.pciBusID,
       prop.pciDeviceID);
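
The ROCm branch hard-codes the "domain:bus:device.function" layout that the NVML_DEVICE_PCI_BUS_ID_FMT macro provides on the CUDA side, since nvml.h cannot be included under ROCm. A quick illustration (hypothetical device values) of the string this produces:

  #include <cstdio>

  int main() {
    char busId[32];
    // Hypothetical device: PCI domain 0, bus 0x63, device 0, function 0.
    snprintf(busId, sizeof(busId), "%08X:%02X:%02X.0", 0, 0x63, 0);
    printf("%s\n", busId);  // prints 00000000:63:00.0
    return 0;
  }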
@@ -346,8 +382,6 @@ bool IntraNodeComm::rendezvous() {
   buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev();
   topoInfo_ = topoInfo;
   return true;
-#endif
-  return false;
 }
 
 } // namespace c10d::intra_node_comm

torch/csrc/distributed/c10d/intra_node_comm.cu

Lines changed: 23 additions & 10 deletions
@@ -4,6 +4,12 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#if defined(USE_ROCM)
+#include <hip/amd_detail/amd_hip_bf16.h>
+#include <hip/amd_detail/amd_hip_atomic.h>
+#include <hip/amd_detail/hip_ldg.h>
+#endif
+
 namespace c10d {
 namespace intra_node_comm {

@@ -17,7 +23,7 @@ static constexpr size_t kOneShotThreshBytes = 256 * 1024;
 static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
 
 #if defined(USE_ROCM)
-using __nv_bfloat162 = uint32_t;
+using __nv_bfloat162 = __hip_bfloat162;
 #endif
 
 struct __align__(16) bf16x8 {
@@ -28,10 +34,7 @@ struct __align__(16) bf16x8 {
 
 DEVICE_INLINE __nv_bfloat162
 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
-#if defined(USE_ROCM)
-  CUDA_KERNEL_ASSERT(false);
-  return 0;
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
   __nv_bfloat162 res;
   return res;
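
With __nv_bfloat162 aliased to __hip_bfloat162, the ROCm build can take the same generic branch as CUDA instead of asserting; the packed-bf16 arithmetic comes from the amd_hip_bf16.h header added above. A minimal device-side sketch (mine; it assumes __hadd2 is the packed add used on the non-assert path, which is outside the lines shown):

  #include <hip/amd_detail/amd_hip_bf16.h>

  // Sketch: element-wise add of two packed bfloat16 lanes on ROCm.
  __device__ __hip_bfloat162 packedAdd(__hip_bfloat162 a, __hip_bfloat162 b) {
    return __hadd2(a, b);  // adds the low and high bf16 lanes independently
  }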
@@ -70,8 +73,12 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
 */
 template <typename T>
 DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
+#elif defined(USE_ROCM)
+  ulonglong2 l_val = __ldg(reinterpret_cast<const ulonglong2*>(addr));
+  reinterpret_cast<unsigned long long*>(&val)[0] = l_val.data[0];
+  reinterpret_cast<unsigned long long*>(&val)[1] = l_val.data[1];
 #else
   unsigned long long int low, high;
   asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
@@ -83,8 +90,13 @@ DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
 }
 
 __device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
+#elif defined(USE_ROCM)
+  for (int i = 0; i < 8; i++) {
+    addr[i] = reinterpret_cast<const at::BFloat16*>(&val)[i];
+  }
 #else
   unsigned long long int low, high;
   low = reinterpret_cast<const unsigned long long int*>(&val)[0];
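
HIP has no PTX inline assembly, so the port loads 16 bytes through __ldg on a ulonglong2 and stores the eight bf16 elements one by one. A self-contained device sketch of the load half (my own; note the diff accesses the vector through the HIP-internal .data[] array, while .x/.y is the equivalent member spelling):

  #include <hip/hip_runtime.h>

  struct __align__(16) Vec128 {
    unsigned long long lo, hi;  // 16 bytes, mirroring bf16x8's footprint
  };

  // Sketch: 128-bit read-only load on HIP (assumes addr is 16-byte aligned,
  // as the __align__(16) on bf16x8 guarantees in the real code).
  __device__ void load128(Vec128& val, const void* addr) {
    ulonglong2 v = __ldg(reinterpret_cast<const ulonglong2*>(addr));
    val.lo = v.x;
    val.hi = v.y;
  }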
@@ -104,15 +116,16 @@ DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
 }
 
 DEVICE_INLINE void releaseSignal(uint32_t* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
 #else
   atomicAdd_system(addr, 1);
+  __threadfence_system();
 #endif
 }
 
 DEVICE_INLINE void acquireSignal(uint32_t* addr) {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
 #else
   volatile uint32_t* signal = addr;
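
Dropping USE_ROCM from these guards routes ROCm through the same atomicAdd_system path as CUDA, and the added __threadfence_system orders the signal against the GPU's other system-scope memory traffic. A minimal sketch of the pattern as paired on both sides (my own; the spin condition in the real acquireSignal sits outside the lines shown, so the one below is hypothetical):

  #include <hip/hip_runtime.h>  // the same pattern compiles as CUDA too

  // Sketch: cross-GPU release/acquire signaling at system scope.
  __device__ void release(uint32_t* flag) {
    atomicAdd_system(flag, 1);   // bump the peer-visible counter
    __threadfence_system();      // order it against surrounding accesses
  }

  __device__ void acquire(uint32_t* flag) {
    volatile uint32_t* signal = flag;
    while (*signal == 0) {       // spin until a peer releases (hypothetical condition)
    }
    __threadfence_system();      // make the peer's writes visible before proceeding
  }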
@@ -484,7 +497,7 @@ static void getLaunchConfig(
 }
 
 bool isIntraNodeCommSupported() {
-#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   return false;
 #else
   return true;
