File tree Expand file tree Collapse file tree 1 file changed +16
-3
lines changed Expand file tree Collapse file tree 1 file changed +16
-3
lines changed Original file line number Diff line number Diff line change @@ -928,9 +928,22 @@ def chat(
928
928
self .model_forward , fullgraph = True , ** kwargs
929
929
)
930
930
931
- self .decode_one_token = torch .compile (
932
- self .decode_one_token , fullgraph = True , ** kwargs
933
- )
931
+ if self .model .config .model_type == ModelType .Flamingo :
932
+ # Based on https://github.com/pytorch/torchtune/blob/57ab583c84c4a9dcacac23aeabc81f2a679670fe/torchtune/training/_compile.py#L42-L52
933
+ from torchtune .modules import (
934
+ TransformerCrossAttentionLayer ,
935
+ TransformerSelfAttentionLayer ,
936
+ )
937
+ decoder = self .model .model .decoder
938
+ for m in reversed (list (decoder .modules ())):
939
+ if isinstance (m , TransformerSelfAttentionLayer ) or isinstance (
940
+ m , TransformerCrossAttentionLayer
941
+ ):
942
+ m .compile ()
943
+ else :
944
+ self .decode_one_token = torch .compile (
945
+ self .decode_one_token , fullgraph = True , ** kwargs
946
+ )
934
947
935
948
if generator_args .compile_prefill :
936
949
self .prefill = torch .compile (
You can’t perform that action at this time.
0 commit comments