We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 587f2f8 commit e7d39c2 — Copy full SHA for e7d39c2
examples/models/llama/export_llama_lib.py
@@ -1273,6 +1273,7 @@ def _get_source_transforms( # noqa
1273
preq_mode: Optional[str] = None,
1274
preq_group_size: Optional[int] = None,
1275
preq_embedding_quantize: Optional[str] = None,
1276
+ local_global_attention: Optional[List[int]] = None,
1277
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
1278
"""
1279
Return a list of functions that transform a graph.
@@ -1467,7 +1468,7 @@ class Args:
1467
1468
if vulkan:
1469
transforms.append(replace_with_vulkan_rotary_emb)
1470
- if getattr(args, "local_global_attention", None) is not None:
1471
+ if local_global_attention:
1472
transforms.append(
1473
partial(
1474
replace_kv_cache_with_ring_kv_cache,
0 commit comments