We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 370f526 · commit 310b3a3 (copy full SHA for 310b3a3)
examples/models/llama/export_llama_lib.py
@@ -962,6 +962,7 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    transforms.append(replace_mha_with_inference_mha)
     if args.use_sdpa_with_kv_cache:
         if is_torchtune_model:
             assert (
examples/models/llama2/source_transformation/torchtune/attention.py renamed to examples/models/llama/source_transformation/torchtune/attention.py
examples/models/llama2/source_transformation/torchtune/modules/mha.py renamed to examples/models/llama/source_transformation/torchtune/modules/mha.py
0 commit comments