File tree Expand file tree Collapse file tree 2 files changed +7
-2
lines changed
source_transformation/torchtune Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Original file line number Diff line number Diff line change 70
70
replace_sdpa_with_simple_sdpa ,
71
71
)
72
72
73
- from .source_transformation .vulkan_rope import replace_with_vulkan_rotary_emb
74
-
75
73
from .source_transformation .torchtune .attention import replace_mha_with_inference_mha
76
74
75
+ from .source_transformation .vulkan_rope import replace_with_vulkan_rotary_emb
76
+
77
77
78
78
IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
79
79
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -1019,4 +1019,8 @@ def _get_source_transforms( # noqa
1019
1019
if args .vulkan :
1020
1020
transforms .append (replace_with_vulkan_rotary_emb )
1021
1021
1022
+ print (
1023
+ f"Performing the following source transformations: { [transform .__name__ for transform in transforms ]} "
1024
+ )
1025
+
1022
1026
return transforms
Original file line number Diff line number Diff line change @@ -32,6 +32,7 @@ def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
32
32
else :
33
33
replace_mha_with_inference_mha (child )
34
34
35
+
35
36
def replace_mha_with_inference_mha (module : torch .nn .Module ) -> torch .nn .Module :
36
37
"""
37
38
Replace TorchTune's MHA with an inference friendly version of MHA that
You can’t perform that action at this time.
0 commit comments