Skip to content

Commit 4587852

Browse files
committed
Lint and print
1 parent 1dd12f0 commit 4587852

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@
7070
replace_sdpa_with_simple_sdpa,
7171
)
7272

73-
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
74-
7573
from .source_transformation.torchtune.attention import replace_mha_with_inference_mha
7674

75+
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
76+
7777

7878
IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
7979
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -1019,4 +1019,8 @@ def _get_source_transforms( # noqa
10191019
if args.vulkan:
10201020
transforms.append(replace_with_vulkan_rotary_emb)
10211021

1022+
print(
1023+
f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}"
1024+
)
1025+
10221026
return transforms

examples/models/llama/source_transformation/torchtune/attention.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
3232
else:
3333
replace_mha_with_inference_mha(child)
3434

35+
3536
def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
3637
"""
3738
Replace TorchTune's MHA with an inference friendly version of MHA that

0 commit comments

Comments
 (0)