Skip to content

Commit e7d39c2

Browse files
authored
Fix broken tests
Differential Revision: D74669998. Pull Request resolved: #10866.
1 parent 587f2f8 commit e7d39c2

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1273,6 +1273,7 @@ def _get_source_transforms( # noqa
12731273
preq_mode: Optional[str] = None,
12741274
preq_group_size: Optional[int] = None,
12751275
preq_embedding_quantize: Optional[str] = None,
1276+
local_global_attention: Optional[List[int]] = None,
12761277
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
12771278
"""
12781279
Return a list of functions that transform a graph.
@@ -1467,7 +1468,7 @@ class Args:
14671468
if vulkan:
14681469
transforms.append(replace_with_vulkan_rotary_emb)
14691470

1470-
if getattr(args, "local_global_attention", None) is not None:
1471+
if local_global_attention:
14711472
transforms.append(
14721473
partial(
14731474
replace_kv_cache_with_ring_kv_cache,

0 commit comments

Comments
 (0)