
Commit b4f9162

Jack-Khuu authored and malfet committed
Run the linter on most of the repo (#895)
* Initial linter run
* undoing changes to unsupported files
1 parent ea67b01 commit b4f9162

File tree

12 files changed, +67 -67 lines changed


.ci/scripts/extract-sequence.py
Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 import sys

+
 def print_until_equals(filename):
     output = False
     past_output = False

build/builder.py
Lines changed: 3 additions & 2 deletions

@@ -10,18 +10,18 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
-from utils.measure_time import measure_time

 import torch
 import torch._dynamo.config
 import torch._inductor.config

 from config.model_config import resolve_model_config
+from distributed import init_distributed, ParallelDims, parallelize_llama
 from quantization.quantize import quantize_model
+from utils.measure_time import measure_time

 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims, init_distributed


 @dataclass

@@ -287,6 +287,7 @@ def _init_model_on_meta_device(builder_args):
     else:
         return Transformer.from_name(builder_args.checkpoint_path.parent.name)

+
 def _load_model_gguf(builder_args, only_config=False):
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:

build/utils.py
Lines changed: 1 addition & 1 deletion

@@ -136,7 +136,7 @@ def name_to_dtype(name, device):
     import platform

     if platform.processor() == "arm":
-        device=get_device_str(device)
+        device = get_device_str(device)
         # ARM CPU is faster with float16, MPS with bf16 if supported
         if device == "cpu" or int(platform.mac_ver()[0].split(".")[0]) < 14:
             return torch.float16

cli.py
Lines changed: 9 additions & 8 deletions

@@ -15,7 +15,7 @@
 from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded

-logging.basicConfig(level=logging.INFO,format="%(message)s")
+logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)

 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")

@@ -24,18 +24,19 @@
 ).expanduser()


-# Subcommands related to downloading and managing model artifacts
+# Subcommands related to downloading and managing model artifacts
 INVENTORY_VERBS = ["download", "list", "remove", "where"]

 # List of all supported subcommands in torchchat
 KNOWN_VERBS = ["chat", "browser", "generate", "eval", "export"] + INVENTORY_VERBS

+
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
-        verb not in INVENTORY_VERBS
+        verb not in INVENTORY_VERBS
         and args.model
         and not is_model_downloaded(args.model, args.model_directory)
     ):

@@ -47,11 +48,11 @@ def add_arguments_for_verb(parser, verb: str) -> None:
     # A model can be specified using a positional model name or HuggingFace
     # path. Alternatively, the model can be specified via --gguf-path or via
     # an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.
-
+
     if verb in INVENTORY_VERBS:
         _configure_artifact_inventory_args(parser, verb)
         _add_cli_metadata_args(parser)
-        return
+        return

     parser.add_argument(
         "model",

@@ -164,10 +165,10 @@ def add_arguments_for_verb(parser, verb: str) -> None:
         choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: cpu, cuda, mps",
     )
-
+
     if verb == "eval":
         _add_evaluation_args(parser)
-
+
     parser.add_argument(
         "--hf-token",
         type=str,

@@ -350,7 +351,7 @@ def arg_init(args):
         )

     if sys.version_info.major != 3 or sys.version_info.minor < 10:
-        raise RuntimeError("Please use Python 3.10 or later.")
+        raise RuntimeError("Please use Python 3.10 or later.")

     if hasattr(args, "quantize") and Path(args.quantize).is_file():
         with open(args.quantize, "r") as f:

download.py
Lines changed: 4 additions & 8 deletions

@@ -26,10 +26,7 @@ def _download_hf_snapshot(
     from requests.exceptions import HTTPError

     # Download and store the HF model artifacts.
-    print(
-        f"Downloading {model_config.name} from HuggingFace...",
-        file=sys.stderr
-    )
+    print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
     try:
         snapshot_download(
             model_config.distribution_path,

@@ -54,10 +51,7 @@ def _download_hf_snapshot(
         raise e

     # Convert the model to the torchchat format.
-    print(
-        f"Converting {model_config.name} to torchchat format...",
-        file=sys.stderr
-    )
+    print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
     convert_hf_checkpoint(
         model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
     )

@@ -177,6 +171,7 @@ def remove_main(args) -> None:
         shutil.rmtree(model_dir)
     print("Done.")

+
 # Subcommand to print downloaded model artifacts directory.
 # Asking for location will/should trigger download of model if not available.
 def where_main(args) -> None:

@@ -196,6 +191,7 @@ def where_main(args) -> None:
     print(str(os.path.abspath(model_dir)))
     exit(0)

+
 # Subcommand to download model artifacts.
 def download_main(args) -> None:
     download_and_convert(args.model, args.model_directory, args.hf_token)

eval.py
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,6 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-from utils.measure_time import measure_time
 from build.builder import (
     _initialize_model,
     _initialize_tokenizer,

@@ -22,6 +21,7 @@
 from build.utils import set_precision
 from cli import add_arguments_for_verb, arg_init
 from generate import encode_tokens, model_forward
+from utils.measure_time import measure_time

 torch._dynamo.config.automatic_dynamic_shapes = True
 torch._inductor.config.triton.unique_kernel_names = True

generate.py
Lines changed: 1 addition & 1 deletion

@@ -6,9 +6,9 @@
 import argparse
 import itertools
 import logging
+import os
 import sys
 import time
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple

runner/run.cpp
Lines changed: 15 additions & 13 deletions

@@ -196,16 +196,16 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor =
       torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
-  std::vector<torch::Tensor> inputs{token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
+  std::vector<torch::Tensor> inputs{
+      token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};

   torch::Tensor result = transformer->runner->run(inputs)[0]
                              .to(torch::dtype(torch::kFloat32))
                              .to(torch::kCPU);
   auto logits = result[0].data_ptr();
 #else // __ET_MODEL__
   ManagedTensor pos_managed(pos_buffer, {1}, ScalarType::Long);
-  ManagedTensor tokens_managed(
-      token_buffer, {1, 1}, ScalarType::Long);
+  ManagedTensor tokens_managed(token_buffer, {1, 1}, ScalarType::Long);
   std::vector<EValue> inputs;
   auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
   auto tmp2 = EValue(pos_managed.get_aliasing_tensor());

@@ -811,7 +811,9 @@ void error_usage() {
       " -v <int> (optional) vocab size, default is model-specific.\n");
   fprintf(
       stderr, " -l <int> (optional) llama version (2 or 3), default 2.\n");
-  fprintf(stderr, " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
+  fprintf(
+      stderr,
+      " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
   exit(EXIT_FAILURE);
 }

@@ -884,16 +886,16 @@ int main(int argc, char* argv[]) {
 #ifdef __AOTI_MODEL__
     } else if (argv[i][1] == 'd') {
 #ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
+      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
+        aoti_device = torch::Device(torch::kCUDA);
+      } else
 #endif
-      if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
+          if (strcasecmp(argv[i + 1], "CPU") == 0) {
+        aoti_device = torch::Device(torch::kCPU);
+      } else {
+        fprintf(stderr, "Unknown device %s", argv[i + 1]);
+        exit(1);
+      }
 #endif
     } else {
       error_usage();

scripts/patch_triton.py
Lines changed: 11 additions & 10 deletions

@@ -3,20 +3,21 @@


 from pathlib import Path
+
 import triton


 def patch_def_search_in_jit_py(jit_py: Path) -> None:
-  with jit_py.open() as f:
-    lines = f.readlines()
-  old_match = 'self.src = self.src[self.src.find("def"):]'
-  new_match = 'self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]'
-  lines.insert(4, "import re\n")
-  for idx, line in enumerate(lines):
-    if old_match in line:
-      lines[idx] = line.replace(old_match, new_match)
-  jit_py.write_text("".join(lines))
+    with jit_py.open() as f:
+        lines = f.readlines()
+    old_match = 'self.src = self.src[self.src.find("def"):]'
+    new_match = 'self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]'
+    lines.insert(4, "import re\n")
+    for idx, line in enumerate(lines):
+        if old_match in line:
+            lines[idx] = line.replace(old_match, new_match)
+    jit_py.write_text("".join(lines))
+

 jit_py = Path(triton.__file__).parent / "runtime" / "jit.py"
 patch_def_search_in_jit_py(jit_py)
-
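Aside on what this script patches (the logic is untouched by this commit; only the indentation moved from 2 to 4 spaces): it swaps Triton's self.src.find("def") slice for a multiline regex, so the slice starts at a real function header rather than at any earlier occurrence of "def". A minimal, self-contained sketch of the difference, using a hypothetical source string that is not taken from Triton:

import re

# Hypothetical kernel source: "def" also occurs inside "default" before
# the real function header.
src = '@jit(do_not_specialize=["default"])\ndef kernel(x):\n    return x\n'

# str.find latches onto the "def" embedded in "default"...
print(src[src.find("def"):])
# ...while the regex anchors to a definition at the start of a line.
print(src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():])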

scripts/updown.py
Lines changed: 4 additions & 4 deletions

@@ -86,12 +86,12 @@ def updown_process_line(
     # [ x1 | c2 | x3 ] means "pick one", so we may have to check that and pick one
     # of the options. Probably pick the last option because testing has more likely
     # been performed with the first option!
-    last=True
+    last = True
     if last:
-        line=select_last_option_between_brackets(line)
+        line = select_last_option_between_brackets(line)
     else:
-        line=select_first_option_between_brackets(line)
-
+        line = select_first_option_between_brackets(line)
+
     output(
         remove_text_between_brackets(line),
         replace_list=replace_list,
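The select_*_option_between_brackets helpers are defined elsewhere in updown.py and are not part of this diff; as a reading aid, here is a minimal sketch of the assumed behavior (the implementation below is hypothetical, only the names come from the diff):

import re

def select_last_option_between_brackets(line: str) -> str:
    # Replace each "[ a | b | c ]" group with its last |-separated option.
    return re.sub(r"\[([^]]*)\]", lambda m: m.group(1).split("|")[-1].strip(), line)

print(select_last_option_between_brackets("make [ -j8 | -j1 ] target"))
# -> "make -j1 target"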

torchchat.py
Lines changed: 14 additions & 14 deletions

@@ -11,10 +11,10 @@

 from cli import (
     add_arguments_for_verb,
-    KNOWN_VERBS,
-    INVENTORY_VERBS,
     arg_init,
     check_args,
+    INVENTORY_VERBS,
+    KNOWN_VERBS,
 )

 default_device = "cpu"

@@ -34,19 +34,19 @@
     subparsers.required = True

     VERB_HELP = {
-      "chat": "Chat interactively with a model",
-      "browser": "Chat interactively in a browser",
-      "download": "Download a model from Hugging Face or others",
-      "generate": "Generate responses from a model given a prompt",
-      "eval": "Evaluate a model given a prompt",
-      "export": "Export a model for AOT Inductor or ExecuTorch",
-      "list": "List supported models",
-      "remove": "Remove downloaded model artifacts",
-      "where": "Return directory containing downloaded model artifacts",
+        "chat": "Chat interactively with a model",
+        "browser": "Chat interactively in a browser",
+        "download": "Download a model from Hugging Face or others",
+        "generate": "Generate responses from a model given a prompt",
+        "eval": "Evaluate a model given a prompt",
+        "export": "Export a model for AOT Inductor or ExecuTorch",
+        "list": "List supported models",
+        "remove": "Remove downloaded model artifacts",
+        "where": "Return directory containing downloaded model artifacts",
     }
     for verb in KNOWN_VERBS:
-      subparser = subparsers.add_parser(verb, help=VERB_HELP[verb])
-      add_arguments_for_verb(subparser, verb)
+        subparser = subparsers.add_parser(verb, help=VERB_HELP[verb])
+        add_arguments_for_verb(subparser, verb)

     # Now parse the arguments
     args = parser.parse_args()

@@ -72,7 +72,7 @@
         args.gui = True
         check_args(args, "browser")

-        from browser.browser import main as browser_main
+        from browser.browser import main as browser_main

         browser_main(args)
     elif args.command == "generate":

utils/measure_time.py
Lines changed: 3 additions & 5 deletions

@@ -1,15 +1,13 @@
 from time import perf_counter
 from typing import Optional

+
 class measure_time:
-    def __init__(
-        self,
-        message: Optional[str] = 'Time: {time:.3f} seconds'
-    ):
+    def __init__(self, message: Optional[str] = "Time: {time:.3f} seconds"):
         self.message = message

     def __enter__(
-      self,
+        self,
     ):
         self.start = perf_counter()
         self.message
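For context on where this class is used, see the from utils.measure_time import measure_time moves in build/builder.py and eval.py above. A minimal usage sketch, assuming (it is not shown in this hunk) that __exit__ formats the {time} placeholder with the elapsed perf_counter() delta:

from utils.measure_time import measure_time

# do_work() is a hypothetical stand-in for the code being timed.
with measure_time("Loaded model in {time:.3f} seconds"):
    do_work()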
