Skip to content

Commit 0c2f97c

Browse files
jataylopragupta
authored and committed
[ROCm] Intra-node all reduce initial implementation (#1435)
* Initial commit to port intra_node_comm to ROCm (cherry picked from commit 48d1c33) * gpt-fast running now with intra-node comm (cherry picked from commit 618c54e) --------- Co-authored-by: Prachi Gupta <[email protected]>
1 parent 5187ca9 commit 0c2f97c

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

caffe2/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,10 @@ if(USE_ROCM)
609609
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS)
610610
if(NOT WIN32)
611611
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS)
612+
set_source_files_properties(
613+
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
614+
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
615+
)
612616
endif()
613617
endif()
614618
# caffe2_nvrtc's stubs to driver APIs are useful for HIP.

torch/csrc/distributed/c10d/intra_node_comm.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
1919
#include <c10/cuda/driver_api.h>
2020
#include <nvml.h>
21+
#else
22+
#include <rocm_smi/rocm_smi.h>
2123
#endif
2224

2325
#include <cuda_runtime.h>
@@ -146,7 +148,26 @@ static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
146148
}
147149
return nvlMesh;
148150
#else
149-
return {};
151+
NvlMesh nvlMesh = {};
152+
const auto worldSize = rankToBusId.size();
153+
// For each device, loop over devices connected to it
154+
for (size_t idx = 0; idx < worldSize; ++idx) {
155+
for (size_t link = 0; link < kMaxDevices; ++link) {
156+
if(idx == link) continue;
157+
158+
bool conn = false;
159+
auto ret = rsmi_is_P2P_accessible(idx, link, &conn);
160+
if (ret != RSMI_STATUS_SUCCESS){
161+
LOG(ERROR) << "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret=" << ret;
162+
return {};
163+
}
164+
165+
if (conn){
166+
nvlMesh[idx][link] += 1;
167+
}
168+
}
169+
}
170+
return nvlMesh;
150171
#endif
151172
}
152173

@@ -272,7 +293,6 @@ bool IntraNodeComm::rendezvous() {
272293
if (isInitialized_) {
273294
return true;
274295
}
275-
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
276296
if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
277297
worldSize_ > kMaxDevices) {
278298
return false;
@@ -289,12 +309,28 @@ bool IntraNodeComm::rendezvous() {
289309

290310
DevInfo devInfo{};
291311
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
312+
313+
#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
314+
auto ret = rsmi_init(0);
315+
if (ret != RSMI_STATUS_SUCCESS) {
316+
LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret;
317+
return false;
318+
}
319+
#endif
320+
292321
cudaDeviceProp prop{};
293322
AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
323+
324+
#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
325+
auto pci_format = "%08X:%02X:%02X.0";
326+
#else
327+
auto pci_format = NVML_DEVICE_PCI_BUS_ID_FMT;
328+
#endif
329+
294330
snprintf(
295331
devInfo.busId,
296332
sizeof(devInfo.busId),
297-
NVML_DEVICE_PCI_BUS_ID_FMT,
333+
pci_format,
298334
prop.pciDomainID,
299335
prop.pciBusID,
300336
prop.pciDeviceID);
@@ -344,8 +380,6 @@ bool IntraNodeComm::rendezvous() {
344380
buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev();
345381
topoInfo_ = topoInfo;
346382
return true;
347-
#endif
348-
return false;
349383
}
350384

351385
} // namespace c10d::intra_node_comm

torch/csrc/distributed/c10d/intra_node_comm.cu

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
#include <ATen/cuda/CUDAContext.h>
55
#include <c10/cuda/CUDAGuard.h>
66

7+
#if defined(USE_ROCM)
8+
#include <hip/amd_detail/amd_hip_bf16.h>
9+
#include <hip/amd_detail/amd_hip_atomic.h>
10+
#include <hip/amd_detail/hip_ldg.h>
11+
#endif
12+
713
namespace c10d {
814
namespace intra_node_comm {
915

@@ -17,7 +23,7 @@ static constexpr size_t kOneShotThreshBytes = 256 * 1024;
1723
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
1824

1925
#if defined(USE_ROCM)
20-
using __nv_bfloat162 = uint32_t;
26+
using __nv_bfloat162 = __hip_bfloat162;
2127
#endif
2228

2329
struct __align__(16) bf16x8 {
@@ -28,10 +34,7 @@ struct __align__(16) bf16x8 {
2834

2935
DEVICE_INLINE __nv_bfloat162
3036
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
31-
#if defined(USE_ROCM)
32-
CUDA_KERNEL_ASSERT(false);
33-
return 0;
34-
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
37+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
3538
CUDA_KERNEL_ASSERT(false);
3639
__nv_bfloat162 res;
3740
return res;
@@ -70,8 +73,12 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
7073
*/
7174
template <typename T>
7275
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
73-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
76+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
7477
CUDA_KERNEL_ASSERT(false);
78+
#elif defined(USE_ROCM)
79+
ulonglong2 l_val = __ldg(reinterpret_cast<const ulonglong2*>(addr));
80+
reinterpret_cast<unsigned long long*>(&val)[0] = l_val.data[0];
81+
reinterpret_cast<unsigned long long*>(&val)[1] = l_val.data[1];
7582
#else
7683
unsigned long long int low, high;
7784
asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
@@ -83,8 +90,13 @@ DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
8390
}
8491

8592
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
86-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
93+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
8794
CUDA_KERNEL_ASSERT(false);
95+
#elif defined(USE_ROCM)
96+
for (int i = 0; i < 8; i++)
97+
{
98+
addr[i] = reinterpret_cast<const at::BFloat16*>(&val)[i];
99+
}
88100
#else
89101
unsigned long long int low, high;
90102
low = reinterpret_cast<const unsigned long long int*>(&val)[0];
@@ -104,15 +116,16 @@ DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
104116
}
105117

106118
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
107-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
119+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
108120
CUDA_KERNEL_ASSERT(false);
109121
#else
110122
atomicAdd_system(addr, 1);
123+
__threadfence_system();
111124
#endif
112125
}
113126

114127
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
115-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
128+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
116129
CUDA_KERNEL_ASSERT(false);
117130
#else
118131
volatile uint32_t* signal = addr;
@@ -473,7 +486,7 @@ static void getLaunchConfig(
473486
}
474487

475488
bool isIntraNodeCommSupported() {
476-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
489+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
477490
return false;
478491
#else
479492
return true;

0 commit comments

Comments
 (0)