
Commit e5cfd6d

misc: unused import cleanup (#1092)
## 📌 Description

Short unused import code cleanup.

## 🔍 Related Issues

N/A

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
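For context, a minimal sketch of the kind of finding those hooks surface. The commit page does not show which linters the hooks run, so the rule name below (F401, the usual unused-import code in flake8/Ruff) is an assumption; the diffs that follow simply delete the flagged import lines.

```python
# Hypothetical example (not from this repo): an unused-import finding of the
# kind cleaned up in this commit. `json` is imported but never referenced,
# which an unused-import hook (e.g. flake8/Ruff rule F401 -- an assumption,
# the actual hook configuration isn't shown here) would report; the fix is
# simply to delete that line.
import json  # unused: nothing below refers to it
import math


def circle_area(radius: float) -> float:
    # Only `math` is actually used, so only `math` needs to be imported.
    return math.pi * radius**2


if __name__ == "__main__":
    print(circle_area(2.0))
```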
1 parent 8daa44f commit e5cfd6d

12 files changed: +22 −56 lines changed

benchmarks/bench_groupwise_gemm_fp8_blackwell.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -14,14 +14,12 @@
 limitations under the License.
 """
 
-import pytest
 import torch
 import triton
 import triton.language as tl
 from triton.testing import do_bench
 
-import flashinfer
-from flashinfer.gemm import gemm_fp8_nt_blockscaled, gemm_fp8_nt_groupwise
+from flashinfer.gemm import gemm_fp8_nt_groupwise
 
 
 @triton.jit
```

benchmarks/bench_groupwise_grouped_gemm_fp8_blackwell.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -14,7 +14,6 @@
 limitations under the License.
 """
 
-import numpy as np
 import torch
 from triton.testing import do_bench
 
```

benchmarks/bench_pad_ragged_tensor.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -1,6 +1,3 @@
-import argparse
-from typing import cast
-
 import torch
 from triton.testing import do_bench
 
```

benchmarks/bench_persistent_gemm.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,4 +1,3 @@
-import pytest
 import torch
 import triton
 from triton.testing import do_bench
```

benchmarks/bench_rope.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -6,8 +6,7 @@
 $ python bench_rope.py
 """
 
-import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
```

flashinfer/jit/attention/pytorch.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -243,10 +243,10 @@ def gen_batch_decode_mla_module(
         and dtype_kv == torch.float16
         and dtype_o == torch.float16
     ):
-        logger.info(f"Use tensor-core SM80 version of MLA decode kernel.")
+        logger.info("Use tensor-core SM80 version of MLA decode kernel.")
         arc = "sm80"
     else:
-        logger.info(f"Fall back to cuda-core version of MLA decode kernel.")
+        logger.info("Fall back to cuda-core version of MLA decode kernel.")
         arc = "cuda_core"
 
     uri = get_batch_decode_mla_uri(
@@ -424,7 +424,7 @@ def gen_single_decode_module(
         ], # additional_scalar_names
         ["double", "double", "double", "double"], # additional_scalar_dtypes
         f"DefaultAttention<false, {str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}, {str(pos_encoding_mode == 2).lower()}>", # variant_name
-        f"#include<flashinfer/attention/variants.cuh>", # variant_decl
+        "#include<flashinfer/attention/variants.cuh>", # variant_decl
         pos_encoding_mode=pos_encoding_mode,
         use_sliding_window=use_sliding_window,
         use_logits_soft_cap=use_logits_soft_cap,
@@ -473,22 +473,22 @@ def gen_single_prefill_module(
         ]
         additional_scalar_dtypes = ["double", "double", "double", "double"]
         variant_name = f"DefaultAttention<use_custom_mask, {str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}, {str(pos_encoding_mode == 2).lower()}>"
-        variant_decl = f"#include<flashinfer/attention/variants.cuh>"
+        variant_decl = "#include<flashinfer/attention/variants.cuh>"
     else:
         if not fp8_enabled:
             additional_tensor_names = []
             additional_tensor_dtypes = []
             additional_scalar_names = ["logits_soft_cap", "sm_scale"]
             additional_scalar_dtypes = ["double", "double"]
             variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>"
-            variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
+            variant_decl = "#include<flashinfer/attention/hopper/variants.cuh>"
         else:
            additional_tensor_names = ["scale_q", "scale_k", "scale_v"]
            additional_tensor_dtypes = ["float", "float", "float"]
            additional_scalar_names = ["sm_scale"]
            additional_scalar_dtypes = ["double"]
-           variant_name = f"DefaultFP8Attention"
-           variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
+           variant_name = "DefaultFP8Attention"
+           variant_decl = "#include<flashinfer/attention/hopper/variants.cuh>"
 
     return gen_customize_single_prefill_module(
         backend,
@@ -551,7 +551,7 @@ def gen_pod_module(
     additional_scalar_dtypes = ["float", "float", "float", "float"]
     variant_name_p = f"DefaultAttention<use_custom_mask_p, {str(use_sliding_window_p).lower()}, {str(use_logits_soft_cap_p).lower()}, {str(pos_encoding_mode_p == 2).lower()}>"
    variant_name_d = f"DefaultAttention<use_custom_mask_d, {str(use_sliding_window_d).lower()}, {str(use_logits_soft_cap_d).lower()}, {str(pos_encoding_mode_d == 2).lower()}>"
-    variant_decl = f"#include<flashinfer/attention/variants.cuh>"
+    variant_decl = "#include<flashinfer/attention/variants.cuh>"
 
     return gen_customize_pod_module(
         uri,
@@ -717,7 +717,7 @@ def gen_batch_decode_module(
         ], # additional_scalar_names
         ["double", "double", "double", "double"], # additional_scalar_dtypes
         f"DefaultAttention<false, {str(use_sliding_window).lower()}, {str(use_logits_soft_cap).lower()}, {str(pos_encoding_mode == 2).lower()}>", # variant_name
-        f"#include<flashinfer/attention/variants.cuh>", # variant_decl
+        "#include<flashinfer/attention/variants.cuh>", # variant_decl
         pos_encoding_mode=pos_encoding_mode,
         use_sliding_window=use_sliding_window,
         use_logits_soft_cap=use_logits_soft_cap,
@@ -799,14 +799,14 @@ def gen_batch_prefill_module(
             ]
             additional_scalar_dtypes = ["double", "double", "int64_t"]
             variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>"
-            variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
+            variant_decl = "#include<flashinfer/attention/hopper/variants.cuh>"
         else:
             additional_tensor_names = ["scale_q", "scale_k", "scale_v"]
             additional_tensor_dtypes = ["float", "float", "float"]
             additional_scalar_names = ["sm_scale"]
             additional_scalar_dtypes = ["double"]
-            variant_name = f"DefaultFP8Attention"
-            variant_decl = f"#include<flashinfer/attention/hopper/variants.cuh>"
+            variant_name = "DefaultFP8Attention"
+            variant_decl = "#include<flashinfer/attention/hopper/variants.cuh>"
 
     return gen_customize_batch_prefill_module(
         backend,
```
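The edits above drop the `f` prefix from strings that contain no placeholders. A minimal standalone sketch of why the prefix is redundant there; the rule name (F541) is an assumption about the linter, since the commit only shows the resulting edits:

```python
# An f-string without any {placeholders} evaluates to the same value as a
# plain string literal, so the `f` prefix is dead weight (commonly flagged
# as F541 by Ruff/flake8-style linters -- naming the rule is an assumption).
msg_f = f"#include<flashinfer/attention/variants.cuh>"
msg_plain = "#include<flashinfer/attention/variants.cuh>"
assert msg_f == msg_plain  # identical at runtime

# The prefix only matters once something is interpolated:
use_sliding_window = True
variant_name = f"DefaultAttention<false, {str(use_sliding_window).lower()}>"
print(variant_name)  # DefaultAttention<false, true>
```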

flashinfer/mla.py

Lines changed: 2 additions & 10 deletions
```diff
@@ -14,22 +14,14 @@
 limitations under the License.
 """
 
-import functools
-from types import SimpleNamespace
-from typing import List, Literal, Optional, Tuple, Union, overload
+from typing import Literal, Optional, Tuple, Union, overload
 
 import torch
 
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import gen_batch_mla_module, gen_jit_spec, sm100a_nvcc_flags
-from .utils import (
-    MaskMode,
-    _check_shape_dtype_device,
-    determine_mla_backend,
-    register_custom_op,
-    register_fake_op,
-)
+from .utils import MaskMode, _check_shape_dtype_device, determine_mla_backend
 
 
 def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
```

flashinfer/pod.py

Lines changed: 4 additions & 18 deletions
```diff
@@ -14,26 +14,16 @@
 limitations under the License.
 """
 
-import functools
-import logging
 import math
 from types import SimpleNamespace
-from typing import Any, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
-from .decode import get_batch_decode_module
-from .jit import (
-    gen_batch_decode_module,
-    gen_batch_prefill_module,
-    gen_customize_batch_prefill_module,
-    gen_pod_module,
-    gen_single_prefill_module,
-    get_pod_uri,
-)
-from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
+from .jit import gen_pod_module
+from .page import get_seq_lens
 from .prefill import get_batch_prefill_module
-from .quantization import packbits, segment_packbits
+from .quantization import packbits
 from .utils import (
     MaskMode,
     PosEncodingMode,
@@ -46,10 +36,6 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
-    determine_attention_backend,
-    is_float8,
-    register_custom_op,
-    register_fake_op,
 )
 
 _pod_modules = {}
```

flashinfer/prefill.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -84,7 +84,7 @@ def get_fmha_module(
             use_logits_soft_cap,
         ).build_and_load()
     else:
-        raise ValueError(f"SM100A is not supported on this device")
+        raise ValueError("SM100A is not supported on this device")
 
 
 def get_single_prefill_module(backend):
```

flashinfer/sparse.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -352,7 +352,7 @@ def plan(
         if (
             R * (num_qo_heads // num_kv_heads) < 4
             and mask_mode != MaskMode.CUSTOM.value
-            and not q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]
+            and q_data_type not in [torch.float8_e4m3fn, torch.float8_e5m2]
         ):
             # If the operation is not compute-bound, we use the cuda-core implementation
             self._use_tensor_cores = False
```
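This change rewrites `not x in seq` as the idiomatic `x not in seq`. Both spellings are equivalent at runtime; a minimal check below, assuming a PyTorch build that exposes the fp8 dtypes used in the diff (the E713 rule name is likewise an assumption):

```python
# `not x in seq` and `x not in seq` are equivalent, but the latter is the
# idiomatic spelling (typically reported as E713 by pycodestyle/Ruff -- an
# assumption; the commit only shows the rewritten condition).
import torch

fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
q_data_type = torch.float16

assert (not q_data_type in fp8_dtypes) == (q_data_type not in fp8_dtypes)
print(q_data_type not in fp8_dtypes)  # True: fp16 queries are not fp8
```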

flashinfer/triton/norm.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,4 +1,3 @@
-from collections.abc import Mapping
 from typing import Optional
 
 import torch
```

flashinfer/triton/page.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -14,9 +14,6 @@
 limitations under the License.
 """
 
-from typing import Optional, Tuple, Union
-
-import torch
 import triton
 import triton.language as tl
 
```
