Skip to content

Commit 0b4b56a

Browse files
GregoryComermalfet
authored andcommitted
Implement list and remove commands (#443)
1 parent 485a268 commit 0b4b56a

File tree

6 files changed

+127
-20
lines changed

6 files changed

+127
-20
lines changed

.github/workflows/pull.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,10 @@ jobs:
364364
- name: Test download
365365
run: |
366366
367+
python torchchat.py list
368+
python torchchat.py download stories15m
367369
python torchchat.py generate stories15M
370+
python torchchat.py remove stories15m
368371
369372
test-tinystories-eager:
370373
strategy:

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ HuggingFace.
5050
python3 torchchat.py download llama3
5151
```
5252

53+
View available models with `python3 torchchat.py list`. You can also remove downloaded models
54+
with `python3 torchchat.py remove llama3`.
55+
5356
## What can you do with torchchat?
5457

5558
* Run models via PyTorch / Python:

cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def check_args(args, name: str) -> None:
2121
# Handle model download. Skip this for download, since it has slightly
2222
# different semantics.
2323
if (
24-
name != "download"
24+
name not in ["download", "list", "remove"]
2525
and args.model
2626
and not is_model_downloaded(args.model, args.model_directory)
2727
):
@@ -61,6 +61,13 @@ def add_arguments_for_export(parser):
6161
# Only export specific options should be here
6262
_add_arguments_common(parser)
6363

64+
def add_arguments_for_list(parser):
65+
# Only list specific options should be here
66+
_add_arguments_common(parser)
67+
68+
def add_arguments_for_remove(parser):
69+
# Only remove specific options should be here
70+
_add_arguments_common(parser)
6471

6572
def _add_arguments_common(parser):
6673
# Model specification. TODO Simplify this.

config/model_config.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,31 +53,35 @@ class ModelConfig:
5353
model_configs: Dict[str, ModelConfig] = None
5454

5555

56-
def resolve_model_config(model: str) -> ModelConfig:
56+
def load_model_configs() -> Dict[str, ModelConfig]:
5757
global model_aliases
5858
global model_configs
5959

60-
model = model.lower()
60+
model_aliases = {}
61+
model_configs = {}
6162

62-
# Lazy load model config from JSON.
63-
if not model_configs:
64-
model_aliases = {}
65-
model_configs = {}
63+
with open(
64+
Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
65+
) as f:
66+
model_config_dict = json.load(f)
67+
68+
for key, value in model_config_dict.items():
69+
config = ModelConfig(**value)
70+
config.name = key
6671

67-
with open(
68-
Path(__file__).parent.parent / "config" / "data" / "models.json", "r"
69-
) as f:
70-
model_config_dict = json.load(f)
72+
key = key.lower()
73+
model_configs[key] = config
7174

72-
for key, value in model_config_dict.items():
73-
config = ModelConfig(**value)
74-
config.name = key
75+
for alias in config.aliases:
76+
model_aliases[alias.lower()] = key
7577

76-
key = key.lower()
77-
model_configs[key] = config
78+
return model_configs
7879

79-
for alias in config.aliases:
80-
model_aliases[alias.lower()] = key
80+
def resolve_model_config(model: str) -> ModelConfig:
81+
model = model.lower()
82+
# Lazy load model config from JSON.
83+
if not model_configs:
84+
load_model_configs()
8185

8286
if model in model_aliases:
8387
model = model_aliases[model]

download.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from config.model_config import (
1414
ModelConfig,
1515
ModelDistributionChannel,
16+
load_model_configs,
1617
resolve_model_config,
1718
)
1819

@@ -107,5 +108,70 @@ def is_model_downloaded(model: str, models_dir: Path) -> bool:
107108
return os.path.isdir(model_dir) and os.listdir(model_dir)
108109

109110

110-
def main(args):
111+
# Subcommand to list available models.
112+
def list_main(args) -> None:
113+
# TODO It would be nice to have argparse validate this. However, we have
114+
# model as an optional named parameter for all subcommands, so we'd
115+
# probably need to move it to be registered per-command.
116+
if args.model:
117+
print("Usage: torchchat.py list")
118+
return
119+
120+
model_configs = load_model_configs()
121+
122+
# Build the table in-memory so that we can align the text nicely.
123+
name_col = []
124+
aliases_col = []
125+
installed_col = []
126+
127+
for name, config in model_configs.items():
128+
is_downloaded = is_model_downloaded(name, args.model_directory)
129+
130+
name_col.append(name)
131+
aliases_col.append(", ".join(config.aliases))
132+
installed_col.append('Yes' if is_downloaded else "")
133+
134+
cols = {
135+
"Model": name_col,
136+
"Aliases": aliases_col,
137+
"Downloaded": installed_col
138+
}
139+
140+
# Find the length of the longest value in each column.
141+
col_widths = {key:max(*[len(s) for s in vals], len(key)) + 1 for (key,vals) in cols.items()}
142+
143+
# Display header.
144+
print()
145+
print(*[val.ljust(width) for (val, width) in col_widths.items()])
146+
print(*["-" * width for width in col_widths.values()])
147+
148+
for i in range(len(name_col)):
149+
row = [col[i] for col in cols.values()]
150+
print(*[val.ljust(width) for (val, width) in zip(row, col_widths.values())])
151+
print()
152+
153+
154+
# Subcommand to remove downloaded model artifacts.
155+
def remove_main(args) -> None:
156+
# TODO It would be nice to have argparse validate this. However, we have
157+
# model as an optional named parameter for all subcommands, so we'd
158+
# probably need to move it to be registered per-command.
159+
if not args.model:
160+
print("Usage: torchchat.py remove <model-or-alias>")
161+
return
162+
163+
model_config = resolve_model_config(args.model)
164+
model_dir = args.model_directory / model_config.name
165+
166+
if not os.path.isdir(model_dir):
167+
print(f"Model {args.model} has no downloaded artifacts.")
168+
return
169+
170+
print(f"Removing downloaded model artifacts for {args.model}...")
171+
shutil.rmtree(model_dir)
172+
print("Done.")
173+
174+
175+
# Subcommand to download model artifacts.
176+
def download_main(args) -> None:
111177
download_and_convert(args.model, args.model_directory, args.hf_token)

torchchat.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
add_arguments_for_eval,
1818
add_arguments_for_export,
1919
add_arguments_for_generate,
20+
add_arguments_for_list,
21+
add_arguments_for_remove,
2022
arg_init,
2123
check_args,
2224
)
@@ -76,6 +78,18 @@
7678
)
7779
add_arguments_for_export(parser_export)
7880

81+
parser_list = subparsers.add_parser(
82+
"list",
83+
help="List supported models",
84+
)
85+
add_arguments_for_list(parser_list)
86+
87+
parser_remove = subparsers.add_parser(
88+
"remove",
89+
help="Remove downloaded model artifacts",
90+
)
91+
add_arguments_for_remove(parser_remove)
92+
7993
# Move all flags to the front of sys.argv since we don't
8094
# want to use the subparser syntax
8195
flag_args = []
@@ -143,7 +157,7 @@
143157
subprocess.run(command)
144158
elif args.command == "download":
145159
check_args(args, "download")
146-
from download import main as download_main
160+
from download import download_main
147161

148162
download_main(args)
149163
elif args.command == "generate":
@@ -160,5 +174,15 @@
160174
from export import main as export_main
161175

162176
export_main(args)
177+
elif args.command == "list":
178+
check_args(args, "list")
179+
from download import list_main
180+
181+
list_main(args)
182+
elif args.command == "remove":
183+
check_args(args, "remove")
184+
from download import remove_main
185+
186+
remove_main(args)
163187
else:
164188
parser.print_help()

0 commit comments

Comments
 (0)