Commit 1371a41

Add Initial Compile for Llama 3.2 11B: Decoder TransformerSelfAttentionLayer, TransformerCrossAttentionLayer (#1287)
1 parent 438ebb1


torchchat/generate.py

Lines changed: 16 additions & 3 deletions
@@ -928,9 +928,22 @@ def chat(
                 self.model_forward, fullgraph=True, **kwargs
             )
 
-            self.decode_one_token = torch.compile(
-                self.decode_one_token, fullgraph=True, **kwargs
-            )
+            if self.model.config.model_type == ModelType.Flamingo:
+                # Based on https://github.com/pytorch/torchtune/blob/57ab583c84c4a9dcacac23aeabc81f2a679670fe/torchtune/training/_compile.py#L42-L52
+                from torchtune.modules import (
+                    TransformerCrossAttentionLayer,
+                    TransformerSelfAttentionLayer,
+                )
+                decoder = self.model.model.decoder
+                for m in reversed(list(decoder.modules())):
+                    if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
+                        m, TransformerCrossAttentionLayer
+                    ):
+                        m.compile()
+            else:
+                self.decode_one_token = torch.compile(
+                    self.decode_one_token, fullgraph=True, **kwargs
+                )
 
             if generator_args.compile_prefill:
                 self.prefill = torch.compile(
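For context, this change switches Flamingo-type models (the multimodal Llama 3.2 11B architecture) from whole-graph compilation of decode_one_token to per-layer compilation: each TransformerSelfAttentionLayer and TransformerCrossAttentionLayer in the decoder is compiled in place via nn.Module.compile(), following the torchtune compile helper cited in the diff. Below is a minimal, self-contained sketch of the same pattern; the SelfAttentionBlock class and the toy decoder are illustrative stand-ins for the torchtune modules, not code from this commit, and a PyTorch version new enough to provide the in-place nn.Module.compile() method is assumed.

import torch
import torch.nn as nn

# Stand-in for torchtune's TransformerSelfAttentionLayer /
# TransformerCrossAttentionLayer; any nn.Module subclass works the same way.
class SelfAttentionBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return self.norm(x + out)

decoder = nn.Sequential(*[SelfAttentionBlock(64) for _ in range(4)])

# Compile matching layers in place rather than wrapping the whole model.
# reversed() walks child modules before their containers, as in the commit.
for m in reversed(list(decoder.modules())):
    if isinstance(m, SelfAttentionBlock):
        m.compile()  # in-place equivalent of m = torch.compile(m)

x = torch.randn(2, 16, 64)
y = decoder(x)  # first call triggers compilation of each compiled layer
print(y.shape)  # torch.Size([2, 16, 64])

Compiling each attention layer separately keeps the compiled regions small, which avoids forcing the entire mixed self-/cross-attention decoder through a single fullgraph=True trace the way the replaced torch.compile(self.decode_one_token, ...) call did.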
