@@ -67,12 +67,13 @@ def generate( # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@ def generate( # noqa: C901
                     [[current_token]], dtype=torch.long, device=self.device
                 ),
                 input_pos=torch.tensor(
-                    [len(tokens) - 1], dtype=torch.long, device=self.device
+                    [pos_base + len(tokens) - 1],
+                    dtype=torch.long,
+                    device=self.device,
                 ),
             )
         else:
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model, reading user prompts
+        from stdin until the user types "exit".
+
+        Args:
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+
+        Returns:
+            Generated list of tokens over all chat turns.
+
+        Note:
+            Each turn reuses the KV cache: the new turn's tokens start at
+            position ``pos_base = len(tokens)`` so earlier context is kept.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
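The heart of this change is the new pos_base argument: each chat turn is prefilled and decoded starting at the position just past everything already in the KV cache, so earlier turns stay resident and the model keeps the conversation context. Below is a minimal, self-contained sketch of that position bookkeeping only; fake_generate, n_new, and the toy token lists are made up for illustration and are not the runner's actual API.

from typing import List

def fake_generate(prompt_tokens: List[int], pos_base: int, n_new: int = 3) -> List[int]:
    # Stand-in for generate(..., echo=True): the prompt tokens would be written
    # into the KV cache at positions pos_base, pos_base + 1, ..., and decoding
    # would continue from there for n_new more tokens.
    positions = list(range(pos_base, pos_base + len(prompt_tokens) + n_new))
    print(f"cache positions used this turn: {positions}")
    return prompt_tokens + [0] * n_new  # echo=True: prompt tokens plus generated tokens

tokens: List[int] = []
for turn_prompt in ([1, 2, 3], [4, 5]):  # two hypothetical chat turns
    new_tokens = fake_generate(turn_prompt, pos_base=len(tokens))
    tokens.extend(new_tokens)
# Turn 1 occupies positions 0..5 and turn 2 continues at 6..10 with no overlap,
# which is why chat_completion passes pos_base=len(tokens) on every turn.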