Skip to content

Commit 8daa44f

Browse files
authored
comm: refactor and initialize flashinfer.comm module (#1089)
1 parent 7f4d1db commit 8daa44f

File tree

13 files changed

+436
-1032
lines changed

13 files changed

+436
-1032
lines changed

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
[submodule "3rdparty/googletest"]
55
path = 3rdparty/googletest
66
url = https://github.com/google/googletest.git
7-
[submodule "3rdparty/mscclpp"]
8-
path = 3rdparty/mscclpp
9-
url = https://github.com/microsoft/mscclpp.git
107
[submodule "3rdparty/cutlass"]
118
path = 3rdparty/cutlass
129
url = https://github.com/NVIDIA/cutlass.git

3rdparty/mscclpp

Lines changed: 0 additions & 1 deletion
This file was deleted.

CMakeLists.txt

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ flashinfer_option(
4949
flashinfer_option(
5050
FLASHINFER_NORM
5151
"Whether to compile normalization kernel tests/benchmarks or not." OFF)
52-
flashinfer_option(
53-
FLASHINFER_DISTRIBUTED
54-
"Whether to compile distributed kernel tests/benchmarks or not." OFF)
5552
flashinfer_option(FLASHINFER_FASTDIV_TEST
5653
"Whether to compile fastdiv kernel tests or not." OFF)
5754
flashinfer_option(FLASHINFER_FASTDEQAUNT_TEST
@@ -83,11 +80,7 @@ if(FLASHINFER_PREFILL
8380
OR FLASHINFER_NORM)
8481
message(STATUS "NVBench and GoogleTest enabled")
8582
add_subdirectory(3rdparty/nvbench)
86-
if(FLASHINFER_DISTRIBUTED)
87-
add_subdirectory(3rdparty/mscclpp)
88-
else(FLASHINFER_DISTRIBUTED)
89-
add_subdirectory(3rdparty/googletest)
90-
endif(FLASHINFER_DISTRIBUTED)
83+
add_subdirectory(3rdparty/googletest)
9184
endif(
9285
FLASHINFER_PREFILL
9386
OR FLASHINFER_DECODE
@@ -457,29 +450,3 @@ if(FLASHINFER_FASTDEQUANT_TEST)
457450
test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
458451
target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main)
459452
endif(FLASHINFER_FASTDEQUANT_TEST)
460-
461-
if(FLASHINFER_DISTRIBUTED)
462-
find_package(MPI REQUIRED)
463-
464-
message(STATUS "Compile sum all-reduce kernel tests.")
465-
file(GLOB_RECURSE TEST_DIST_SUM_ALL_REDUCE_SRCS
466-
${PROJECT_SOURCE_DIR}/src/test_sum_all_reduce.cu)
467-
add_executable(test_sum_all_reduce ${TEST_DIST_SUM_ALL_REDUCE_SRCS})
468-
target_include_directories(
469-
test_sum_all_reduce
470-
PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include
471-
3rdparty/spdlog/include)
472-
target_link_libraries(test_sum_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
473-
target_compile_definitions(test_sum_all_reduce PRIVATE -DENABLE_MPI)
474-
475-
message(STATUS "Compile attention allreduce kernel tests.")
476-
file(GLOB_RECURSE TEST_DIST_ATTN_ALL_REDUCE_SRCS
477-
${PROJECT_SOURCE_DIR}/src/test_attn_all_reduce.cu)
478-
add_executable(test_attn_all_reduce ${TEST_DIST_ATTN_ALL_REDUCE_SRCS})
479-
target_include_directories(
480-
test_attn_all_reduce
481-
PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include
482-
3rdparty/spdlog/include)
483-
target_link_libraries(test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
484-
target_compile_definitions(test_attn_all_reduce PRIVATE -DENABLE_MPI)
485-
endif(FLASHINFER_DISTRIBUTED)

cmake/config.cmake

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ set(FLASHINFER_NORM ON)
1919
set(FLASHINFER_FASTDIV_TEST ON)
2020
# Whether to compile fastdequant tests
2121
set(FLASHINFER_FASTDEQUANT_TEST ON)
22-
# Whether to compile distributed tests
23-
set(FLASHINFER_DISTRIBUTED ON)
2422
# The following configurations can impact the binary size of the generated
2523
# library
2624
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256 512)

csrc/custom_all_reduce.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// flashinfer: adapted from sglang + vllm code
22
// refer to: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
3-
#include "flashinfer/distributed/custom_all_reduce.cuh"
3+
#include "flashinfer/comm/custom_all_reduce.cuh"
44
#include "pytorch_extension_utils.h"
55

66
// Fake pointer type, must match fptr_t type in ops.h.

flashinfer/aot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .activation import act_func_def_str, gen_act_and_mul_module
1212
from .cascade import gen_cascade_module
13-
from .custom_all_reduce import gen_comm_module
13+
from .comm import gen_comm_module
1414
from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
1515
from .jit import JitSpec, build_jit_specs
1616
from .jit import env as jit_env

0 commit comments

Comments
 (0)