Skip to content

FlashInfer Windows support #964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions csrc/activation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

#include "pytorch_extension_utils.h"

#ifdef _WIN32
#define M_SQRT1_2 0.707106781186547524401
#endif

using namespace flashinfer;

__device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); }
Expand Down
10 changes: 6 additions & 4 deletions csrc/batch_decode_config.inc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ using IdType = int32_t;
DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \
using DTypeO = DTypeQ; \
using Params = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
static constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
[[maybe_unused]] static constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
static constexpr bool stat_use_slid_window = USE_SLIDING_WINDOW; \
static constexpr bool stat_use_logits_soft_cap = USE_LOGITS_SOFT_CAP; \
using AttentionVariant = \
DefaultAttention</*use_custom_mask=*/false, USE_SLIDING_WINDOW, \
USE_LOGITS_SOFT_CAP, /*use_alibi_bias=*/false>; \
DefaultAttention</*use_custom_mask=*/false, stat_use_slid_window, \
stat_use_logits_soft_cap, /*use_alibi_bias=*/false>; \
__VA_ARGS__(); \
return true; \
}); \
Expand Down
10 changes: 5 additions & 5 deletions csrc/batch_decode_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
static constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
static constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
static constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
static constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
static constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};

struct Params {
using DTypeQ = DTypeQ;
Expand Down
14 changes: 8 additions & 6 deletions csrc/batch_prefill_config.inc
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,18 @@ using IdType = int32_t;
using DTypeO = DTypeQ; \
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
constexpr bool USE_FP16_QK_REDUCTION = false; \
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \
static constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
static constexpr bool USE_FP16_QK_REDUCTION = false; \
static constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \
return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
[[maybe_unused]] static constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
static constexpr bool stat_use_slid_window = USE_SLIDING_WINDOW; \
static constexpr bool stat_use_logits_soft_cap = USE_LOGITS_SOFT_CAP; \
using AttentionVariant = \
DefaultAttention</*use_custom_mask=*/use_custom_mask, USE_SLIDING_WINDOW, \
USE_LOGITS_SOFT_CAP, /*use_alibi_bias=*/false>; \
DefaultAttention</*use_custom_mask=*/use_custom_mask, stat_use_slid_window, \
stat_use_logits_soft_cap, /*use_alibi_bias=*/false>; \
__VA_ARGS__(); \
return true; \
}); \
Expand Down
14 changes: 7 additions & 7 deletions csrc/batch_prefill_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \
static constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \
using AttentionVariant = {{ variant_name }}; \
__VA_ARGS__(); \
})
Expand All @@ -23,12 +23,12 @@ using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
static constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
static constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
static constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
static constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
static constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
static constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};


struct RaggedParams {
Expand Down
43 changes: 43 additions & 0 deletions csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,49 @@
#include "aot_default_additional_params.h"
#include "pytorch_extension_utils.h"

// MSVC compiler only allows to link with a single
// PyInit_flashinfer_kernels defined per .pyd, define here

#ifdef _WIN32
#ifndef FLASHINFER_EXT_MODULE_INITED
#define FLASHINFER_EXT_MODULE_INITED

// To expand macros in #name
#define FLASHINFER_EXT_MODULE_INIT_EXPAND(name) FLASHINFER_EXT_MODULE_INIT(name)

/* Creates a dummy empty module that can be imported from Python.
The import from Python will load the .so consisting of the file
in this extension, so that the TORCH_LIBRARY_FRAGMENT static initializers
are run. */
#ifdef _WIN32
#define FLASHINFER_EXT_MODULE_INIT(name) \
extern "C" { \
PyObject* PyInit_##name(void) { \
static struct PyModuleDef module_def = { \
PyModuleDef_HEAD_INIT, \
#name, /* name of module */ \
NULL, /* module documentation, may be NULL */ \
-1, /* size of per-interpreter state of the module, \
or -1 if the module keeps state in global variables. */ \
NULL, /* methods */ \
NULL, /* slots */ \
NULL, /* traverse */ \
NULL, /* clear */ \
NULL, /* free */ \
}; \
return PyModule_Create(&module_def); \
} \
}
#endif

FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)

#undef FLASHINFER_EXT_MODULE_INIT
#undef FLASHINFER_EXT_MODULE_INIT_EXPAND

#endif
#endif

//========== activation ==========

void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl, int64_t cuda_stream);
Expand Down
43 changes: 43 additions & 0 deletions csrc/flashinfer_ops_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,49 @@
#include "aot_default_additional_params.h"
#include "pytorch_extension_utils.h"

// MSVC compiler only allows to link with a single
// PyInit_flashinfer_kernels defined per .pyd, define here

#ifdef _WIN32
#ifndef FLASHINFER_EXT_MODULE_INITED
#define FLASHINFER_EXT_MODULE_INITED

// To expand macros in #name
#define FLASHINFER_EXT_MODULE_INIT_EXPAND(name) FLASHINFER_EXT_MODULE_INIT(name)

/* Creates a dummy empty module that can be imported from Python.
The import from Python will load the .so consisting of the file
in this extension, so that the TORCH_LIBRARY_FRAGMENT static initializers
are run. */
#ifdef _WIN32
#define FLASHINFER_EXT_MODULE_INIT(name) \
extern "C" { \
PyObject* PyInit_##name(void) { \
static struct PyModuleDef module_def = { \
PyModuleDef_HEAD_INIT, \
#name, /* name of module */ \
NULL, /* module documentation, may be NULL */ \
-1, /* size of per-interpreter state of the module, \
or -1 if the module keeps state in global variables. */ \
NULL, /* methods */ \
NULL, /* slots */ \
NULL, /* traverse */ \
NULL, /* clear */ \
NULL, /* free */ \
}; \
return PyModule_Create(&module_def); \
} \
}
#endif

FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)

#undef FLASHINFER_EXT_MODULE_INIT
#undef FLASHINFER_EXT_MODULE_INIT_EXPAND

#endif
#endif

void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr,
at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride,
Expand Down
Loading