@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
61
61
context .server_metrics = False
62
62
context .server_process = None
63
63
context .seed = None
64
+ context .draft = None
64
65
context .server_seed = None
65
66
context .user_api_key = None
66
67
context .response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
107
108
context .n_gpu_layer = ngl
108
109
109
110
111
@step('{draft:d} as draft')
def step_draft(context, draft):
    """Record the requested speculative-decoding draft size on the context.

    The value is later forwarded to the server as the `--draft` argument
    when the server process is started.
    """
    context.draft = draft
115
+
110
116
@step ('{n_ctx:d} KV cache size' )
111
117
def step_n_ctx (context , n_ctx ):
112
118
context .n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
254
260
assert_n_tokens_predicted (context .completion , predicted_n )
255
261
256
262
263
@step('all predictions are equal')
@async_run_until_complete
async def step_predictions_equal(context):
    """Check that every concurrently gathered completion produced the same content.

    Gathers all pending completion tasks, requires at least two results so the
    comparison is meaningful, asserts their contents are identical, and then
    clears the accumulated results for subsequent steps.
    """
    completed = await gather_tasks_results(context)
    assert completed >= 2, "need at least 2 completions"
    assert_all_predictions_equal(context.tasks_result)
    context.tasks_result = []
270
+
271
+
257
272
@step ('the completion is truncated' )
258
273
def step_assert_completion_truncated (context ):
259
274
step_assert_completion_truncated (context , '' )
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
1020
1035
assert n_predicted == expected_predicted_n , (f'invalid number of tokens predicted:'
1021
1036
f' { n_predicted } <> { expected_predicted_n } ' )
1022
1037
1038
def assert_all_predictions_equal(completion_responses):
    """Assert that every completion response has the same 'content' as the first.

    :param completion_responses: non-empty sequence of completion response
        dicts, each with a 'content' key (raises IndexError if empty).
    :raises AssertionError: if any response's content differs from the first.

    When the DEBUG environment variable is set to 'ON', each content is
    printed to aid troubleshooting.
    """
    # Hoist the loop-invariant DEBUG check out of the loop.
    debug = os.environ.get('DEBUG') == 'ON'

    content_0 = completion_responses[0]['content']
    if debug:
        print(f"content 0: {content_0}")

    # enumerate replaces the original manual i = 1 / i += 1 counter.
    for i, response in enumerate(completion_responses[1:], start=1):
        content = response['content']
        if debug:
            print(f"content {i}: {content}")
        assert content == content_0, "contents not equal"
1054
+
1023
1055
1024
1056
async def gather_tasks_results (context ):
1025
1057
n_tasks = len (context .concurrent_tasks )
@@ -1148,6 +1180,8 @@ def start_server_background(context):
1148
1180
server_args .extend (['--ubatch-size' , context .n_ubatch ])
1149
1181
if context .n_gpu_layer :
1150
1182
server_args .extend (['--n-gpu-layers' , context .n_gpu_layer ])
1183
+ if context .draft is not None :
1184
+ server_args .extend (['--draft' , context .draft ])
1151
1185
if context .server_continuous_batching :
1152
1186
server_args .append ('--cont-batching' )
1153
1187
if context .server_embeddings :
0 commit comments