Skip to content

Commit 6556bd1

Browse files
committed
Fix broken internal tests
Summary: Fix land-time race condition caused by two related diffs landing at the same time (D73800023 and D73891423) Reviewed By: larryliu0820 Differential Revision: D74669998
1 parent a21022c commit 6556bd1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ def _get_source_transforms( # noqa
12731273
preq_mode: Optional[str] = None,
12741274
preq_group_size: Optional[int] = None,
12751275
preq_embedding_quantize: Optional[str] = None,
1276+
local_global_attention: Optional[List[int]] = None,
12761277
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
12771278
"""
12781279
Return a list of functions that transform a graph.
@@ -1467,7 +1468,7 @@ class Args:
14671468
if vulkan:
14681469
transforms.append(replace_with_vulkan_rotary_emb)
14691470

1470-
if getattr(args, "local_global_attention", None) is not None:
1471+
if local_global_attention:
14711472
transforms.append(
14721473
partial(
14731474
replace_kv_cache_with_ring_kv_cache,

0 commit comments

Comments
 (0)