@@ -531,15 +531,22 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: str = "math",
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
         encountered_eos = False
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
         for _i in range(
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([sdp_backend_dict[attention_backend]]):
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
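To make the intent of the new `attention_backend` plumbing concrete, here is a minimal standalone sketch (not part of this diff) of the same mechanism: a string keys into `torch.nn.attention.SDPBackend`, and `sdpa_kernel` restricts `scaled_dot_product_attention` to that backend. The tensor shapes and the `backends` dict name are illustrative; `SDPBackend.CUDNN_ATTENTION` assumes a recent PyTorch (2.4+).

```python
# Illustrative sketch only (not from the PR): restrict SDPA to one backend
# chosen by string, mirroring sdp_backend_dict above.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

backends = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
with sdpa_kernel([backends["math"]]):
    # Only the MATH backend is eligible inside this context manager.
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```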
@@ -683,6 +690,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +807,7 @@ def generate(
                     if self.is_llama3_model
                     else None
                 ),
+                attention_backend=attention_backend,
                 **sampling_kwargs,
             ):
                 generated_tokens.append(generated_token.view(-1))
@@ -1170,6 +1179,10 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
+            if self.builder_args.device == "cpu" and (self.builder_args.attention_backend == "efficient_attention"
+                or self.builder_args.attention_backend == "cudnn_attention"):
+                print(f"Warning: {self.builder_args.attention_backend} is not supported on CPU. Using math instead.")
+                self.builder_args.attention_backend = "math"
             with prof:
                 generator_func = self.generate(
                     self.model,
@@ -1186,6 +1199,7 @@ def callback(x, *, done_generating=False):
                     start_pos=start_pos,
                     skip_cache_setup=not is_first_sample,
                     max_seq_length=max_seq_length,
+                    attention_backend=self.builder_args.attention_backend,
                 )
                 for token_tensor, metrics in generator_func:
                     if token_tensor is not None:
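For reference, the CPU fallback added in the last two hunks can be exercised in isolation. The sketch below factors the same check into a hypothetical `resolve_attention_backend` helper; the helper name is not part of torchchat and is only for illustration.

```python
# Hypothetical helper (not in the PR): same CPU fallback logic as above,
# factored out so it can be tested on its own.
def resolve_attention_backend(device: str, requested: str) -> str:
    cpu_unsupported = {"efficient_attention", "cudnn_attention"}
    if device == "cpu" and requested in cpu_unsupported:
        print(f"Warning: {requested} is not supported on CPU. Using math instead.")
        return "math"
    return requested


print(resolve_attention_backend("cpu", "cudnn_attention"))   # -> math (with warning)
print(resolve_attention_backend("cuda", "flash_attention"))  # -> flash_attention
```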