|
24 | 24 |
|
25 | 25 | from PIL import Image
|
26 | 26 |
|
| 27 | +# torchtune model definition dependencies |
| 28 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 29 | + |
| 30 | +from torchtune.generation import sample as tune_sample |
| 31 | +from torchtune.models.llama3 import llama3_tokenizer |
| 32 | + |
| 33 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
| 34 | +from torchtune.training import set_default_dtype |
| 35 | + |
27 | 36 | from torchchat.cli.builder import (
|
28 | 37 | _initialize_model,
|
29 | 38 | _initialize_tokenizer,
|
|
35 | 44 | from torchchat.utils.build_utils import device_sync, set_precision
|
36 | 45 | from torchchat.utils.device_info import get_device_info
|
37 | 46 |
|
38 |
| -# torchtune model definition dependencies |
39 |
| -from torchtune.data import Message, padded_collate_tiled_images_and_mask |
40 |
| - |
41 |
| -from torchtune.generation import sample as tune_sample |
42 |
| -from torchtune.models.llama3 import llama3_tokenizer |
43 |
| - |
44 |
| -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
45 |
| -from torchtune.training import set_default_dtype |
46 |
| - |
47 | 47 |
|
48 | 48 | class _ChatFormatter(ABC):
|
49 | 49 | def __init__(self, tokenizer):
|
@@ -1155,13 +1155,9 @@ def callback(x, *, done_generating=False):
|
1155 | 1155 | print(
|
1156 | 1156 | f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
|
1157 | 1157 | )
|
1158 |
| - aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec |
1159 |
| - # Don't continue here.... because we need to report and reset |
1160 |
| - # continue |
1161 |
| - else: |
1162 |
| - aggregate_metrics["tokens_per_sec"].append(tokens_sec) |
1163 |
| - aggregate_metrics["first_token_per_sec"].append(first_token_sec) |
1164 |
| - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) |
| 1158 | + aggregate_metrics["tokens_per_sec"].append(tokens_sec) |
| 1159 | + aggregate_metrics["first_token_per_sec"].append(first_token_sec) |
| 1160 | + aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) |
1165 | 1161 |
|
1166 | 1162 | logging.info(
|
1167 | 1163 | f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
|
|
0 commit comments