+ import socket
+ import threading
+ import time
+ from contextlib import closing
+
import openai
import requests
- from behave import *
+ from behave import step
+ from behave.api.async_step import async_run_until_complete
+
+ base_fqdn = 'localhost'
+ base_port = 8080
+ base_url = f"http://{base_fqdn}:{base_port}"

openai.api_key = 'llama.cpp'
- openai.api_base = "http://localhost:8080/v1/chat"
+ openai.api_base = f"{base_url}/v1/chat"
+
+ slow_prompt = 'say hello ' * 10
+ fast_prompt = 'Write a joke'
+
+ n_slots = 2
+
+
+ # Poll the TCP port until the server accepts connections.
+ @step(u'wait for the server to be started')
+ def step_wait_for_the_server_to_be_started(context):
+     server_started = False
+     while not server_started:
+         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+             result = sock.connect_ex((base_fqdn, base_port))
+             if result != 0:
+                 print("server not ready: ", base_fqdn, base_port, result)
+                 time.sleep(1)
+             else:
+                 return 0
+
+
+ # Poll the /health endpoint until it returns HTTP 200.
+ @step(u'wait for the server to be healthy')
+ def step_wait_for_the_server_to_be_healthy(context):
+     status_code = 500
+     while status_code != 200:
+         status_code = requests.get(f'{base_url}/health').status_code
+         if status_code != 200:
+             time.sleep(1)


- @given(u'a prompt {prompt}')
+ @step(u'an health liveness probe')
+ def step_an_health_liveness_probe(context):
+     response = requests.get(f'{base_url}/health')
+     context.status_code = response.status_code
+     context.response_data = response.json()
+
+
+ @step(u'the server must be healthy')
+ def step_server_healthy(context):
+     assert context.status_code == 200
+     assert context.response_data['status'] == 'ok'
+
+
+ @step(u'the server is overloaded')
+ @async_run_until_complete()
+ async def step_server_overloaded(context):
+     response = requests.get(f'{base_url}/health?fail_on_no_slot')
+     assert response.status_code == 503
+     assert response.json()['status'] == 'no slot available'
+
+
+ @step(u'a prompt {prompt}')
def step_prompt(context, prompt):
    context.prompt = prompt


- @when(u'we request a completion')
+ @step(u'we request a completion')
def step_request_completion(context):
-     response = requests.post('http://localhost:8080/completion', json={
+     response = requests.post(f'{base_url}/completion', json={
        "prompt": context.prompt
    })
    status_code = response.status_code
    assert status_code == 200
    context.response_data = response.json()


- @then(u'tokens are predicted')
+ @step(u'tokens are predicted')
def step_request_completion(context):
-     assert len(context.response_data['content']) > 0
-     assert context.response_data['timings']['predicted_n'] > 0
+     prompt_predicted(context.response_data)


- @given(u'a user prompt {user_prompt}')
+ @step(u'a user prompt {user_prompt}')
def step_user_prompt(context, user_prompt):
    context.user_prompt = user_prompt


- @given(u'a system prompt {system_prompt}')
+ @step(u'a system prompt {system_prompt}')
def step_system_prompt(context, system_prompt):
    context.system_prompt = system_prompt


- @given(u'a model {model}')
+ @step(u'a model {model}')
def step_model(context, model):
    context.model = model


- @when(u'we request the oai completions endpoint')
+ @step(u'we request the oai completions endpoint')
def step_oai_completions(context):
    context.chat_completion = openai.Completion.create(
        messages=[
@@ -59,8 +116,67 @@ def step_oai_completions(context):
    )


- @then(u'the oai response contains completion tokens')
+ @step(u'the oai response contains completion tokens')
def step_oai_response_has_completion_tokens(context):
    assert len(context.chat_completion.choices) == 1
    assert len(context.chat_completion.choices[0].message) > 0
    assert context.chat_completion.usage.completion_tokens > 0
+
+
+ # Runs in a worker thread: posts one completion request and stores the
+ # response on the shared behave context.
+ def async_prompt(context, prompt):
+     response = requests.post(f'{base_url}/completion', json={
+         "prompt": prompt
+     })
+
+     context.async_responses.append(response)
+
+
+ @step(u'{n_prompt} {prompt_type} concurrent prompts')
+ def step_n_concurrent_prompts(context, n_prompt, prompt_type):
+     prompt = fast_prompt
+     if prompt_type == 'slow':
+         prompt = slow_prompt
+     context.async_responses = []
+     context.threads = []
+     for i in range(int(n_prompt)):
+         thread = threading.Thread(target=async_prompt, args=(context, prompt))
+         thread.start()
+         context.threads.append(thread)
+
+
+ # Poll /health until the expected number of slots are busy.
+ def wait_for_slots_processing(context, expected_slots_processing):
+     while True:
+         health = requests.get(f'{base_url}/health').json()
+         if 'slots_processing' in health:  # FIXME when #5594 is merged
+             slots_processing = health['slots_processing']
+         else:
+             slots_processing = 0
+         if slots_processing == expected_slots_processing:
+             break
+         else:
+             time.sleep(0.2)
+
+
+ @step(u'wait for all slots processing')
+ def step_wait_for_all_slots_processing(context):
+     wait_for_slots_processing(context, n_slots)
+
+
+ @step(u'wait for all slots idle')
+ def step_wait_for_all_slots_idle(context):
+     wait_for_slots_processing(context, 0)
+
+
+ @step(u'all prompts must be predicted')
+ def step_all_prompts_must_be_predicted(context):
+     for thread in context.threads:
+         thread.join()
+     for async_response in context.async_responses:
+         assert async_response.status_code == 200
+         response_data = async_response.json()
+         prompt_predicted(response_data)
+
+
+ def prompt_predicted(response_data):
+     assert len(response_data['content']) > 0
+     assert response_data['timings']['predicted_n'] > 0
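+
+
+ # Illustrative only: every step above is registered with the generic @step
+ # decorator, so a feature file (assumed here, not shown in this change) could
+ # drive them with any Gherkin keyword, for example:
+ #
+ #   Scenario: Concurrent completion requests
+ #     Given wait for the server to be started
+ #     And   wait for the server to be healthy
+ #     When  2 slow concurrent prompts
+ #     And   wait for all slots processing
+ #     Then  the server is overloaded
+ #     And   wait for all slots idle
+ #     And   all prompts must be predicted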