Skip to content

Commit 48d1c33

Browse files
committed
Initial commit to port intra_node_comm to ROCm
1 parent 6193b40 commit 48d1c33

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

caffe2/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,10 @@ if(USE_ROCM)
722722
append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_HIP_SRCS)
723723
if(NOT WIN32)
724724
append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_HIP_SRCS)
725+
set_source_files_properties(
726+
${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
727+
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
728+
)
725729
endif()
726730
endif()
727731
# caffe2_nvrtc's stubs to driver APIs are useful for HIP.

torch/csrc/distributed/c10d/intra_node_comm.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
// Topology-query headers.
//  - CUDA builds that opted into driver API support use the driver API
//    wrapper plus NVML.
//  - ROCm builds use ROCm SMI.
// A CUDA build without driver API support includes neither, instead of
// falling into a bare `#else` that demands a ROCm-only header.
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <nvml.h>
#elif defined(USE_ROCM)
#include <rocm_smi/rocm_smi.h>
#endif
2224

2325
#include <cuda_runtime.h>
@@ -146,7 +148,26 @@ static NvlMesh getNvlMesh(const std::vector<std::string>& rankToBusId) {
146148
}
147149
return nvlMesh;
148150
#else
149-
return {};
151+
NvlMesh nvlMesh = {};
152+
const auto worldSize = rankToBusId.size();
153+
// For each device, loop over devices connected to it
154+
for (size_t idx = 0; idx < worldSize; ++idx) {
155+
for (size_t link = 0; link < kMaxDevices; ++link) {
156+
if(idx == link) continue;
157+
158+
bool conn = false;
159+
auto ret = rsmi_is_P2P_accessible(idx, link, &conn);
160+
if (ret != RSMI_STATUS_SUCCESS){
161+
LOG(ERROR) << "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret=" << ret;
162+
return {};
163+
}
164+
165+
if (conn){
166+
nvlMesh[idx][link] += 1;
167+
}
168+
}
169+
}
170+
return nvlMesh;
150171
#endif
151172
}
152173

@@ -288,7 +309,6 @@ bool IntraNodeComm::rendezvous() {
288309
if (isInitialized_) {
289310
return true;
290311
}
291-
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
292312
if (!isIntraNodeCommSupported() || !isEnabled() || worldSize_ < 2 ||
293313
worldSize_ > kMaxDevices) {
294314
return false;
@@ -305,12 +325,28 @@ bool IntraNodeComm::rendezvous() {
305325

306326
DevInfo devInfo{};
307327
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
328+
329+
#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Initialize ROCm SMI (init flags = 0). rendezvous() returns bool, so a
  // failed initialization is reported as `false`; `return nullptr;` is not a
  // valid way to return from a bool function.
  auto ret = rsmi_init(0);
  if (ret != RSMI_STATUS_SUCCESS) {
    LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret;
    return false;
  }
#endif
336+
308337
cudaDeviceProp prop{};
309338
AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
339+
340+
#if defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
341+
auto pci_format = "%08X:%02X:%02X.0";
342+
#else
343+
auto pci_format = NVML_DEVICE_PCI_BUS_ID_FMT;
344+
#endif
345+
310346
snprintf(
311347
devInfo.busId,
312348
sizeof(devInfo.busId),
313-
NVML_DEVICE_PCI_BUS_ID_FMT,
349+
pci_format,
314350
prop.pciDomainID,
315351
prop.pciBusID,
316352
prop.pciDeviceID);
@@ -424,8 +460,6 @@ bool IntraNodeComm::rendezvous() {
424460
buffersDev_ = buffersDev;
425461
topoInfo_ = topoInfo;
426462
return true;
427-
#endif
428-
return false;
429463
}
430464

431465
} // namespace c10d::intra_node_comm

torch/csrc/distributed/c10d/intra_node_comm.cu

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
#include <ATen/cuda/CUDAContext.h>
55
#include <c10/cuda/CUDAGuard.h>
66

7+
#if defined(USE_ROCM)
8+
#include <hip/amd_detail/amd_hip_bf16.h>
9+
#include <hip/amd_detail/amd_hip_atomic.h>
10+
#include <hip/amd_detail/hip_ldg.h>
11+
#endif
12+
713
namespace c10d {
814
namespace intra_node_comm {
915

@@ -17,7 +23,7 @@ static constexpr size_t kOneShotThreshBytes = 256 * 1024;
1723
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
1824

1925
#if defined(USE_ROCM)
20-
using __nv_bfloat162 = uint32_t;
26+
using __nv_bfloat162 = __hip_bfloat162;
2127
#endif
2228

2329
struct __align__(16) bf16x8 {
@@ -28,10 +34,7 @@ struct __align__(16) bf16x8 {
2834

2935
DEVICE_INLINE __nv_bfloat162
3036
bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
31-
#if defined(USE_ROCM)
32-
CUDA_KERNEL_ASSERT(false);
33-
return 0;
34-
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
37+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
3538
CUDA_KERNEL_ASSERT(false);
3639
__nv_bfloat162 res;
3740
return res;
@@ -70,8 +73,12 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
7073
*/
7174
template <typename T>
7275
DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
73-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
76+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
7477
CUDA_KERNEL_ASSERT(false);
78+
#elif defined(USE_ROCM)
79+
ulonglong2 l_val = __ldg(reinterpret_cast<const ulonglong2*>(addr));
80+
reinterpret_cast<unsigned long long*>(&val)[0] = l_val.data[0];
81+
reinterpret_cast<unsigned long long*>(&val)[1] = l_val.data[1];
7582
#else
7683
unsigned long long int low, high;
7784
asm("ld.global.nc.v2.u64 {%0, %1}, [%2];"
@@ -83,8 +90,13 @@ DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) {
8390
}
8491

8592
__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) {
86-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
93+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
8794
CUDA_KERNEL_ASSERT(false);
95+
#elif defined(USE_ROCM)
96+
for (int i = 0; i < 8; i++)
97+
{
98+
addr[i] = reinterpret_cast<const at::BFloat16*>(&val)[i];
99+
}
88100
#else
89101
unsigned long long int low, high;
90102
low = reinterpret_cast<const unsigned long long int*>(&val)[0];
@@ -104,15 +116,15 @@ DEVICE_INLINE void store128(T* addr, const bf16x8& val) {
104116
}
105117

106118
/**
 * Bump the 32-bit signal word at `addr` by one using atomicAdd_system.
 *
 * When compiled for a CUDA architecture below 800 the operation is not
 * provided and the function trips CUDA_KERNEL_ASSERT instead. Every other
 * compilation, including ROCm (which no longer has its own guard here),
 * takes the atomic path.
 */
DEVICE_INLINE void releaseSignal(uint32_t* addr) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
#else
  atomicAdd_system(addr, 1);
#endif
}
113125

114126
DEVICE_INLINE void acquireSignal(uint32_t* addr) {
115-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
127+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
116128
CUDA_KERNEL_ASSERT(false);
117129
#else
118130
volatile uint32_t* signal = addr;
@@ -471,7 +483,7 @@ static void getLaunchConfig(
471483
}
472484

473485
bool isIntraNodeCommSupported() {
474-
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
486+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
475487
return false;
476488
#else
477489
return true;

0 commit comments

Comments
 (0)