
Commit 7650153

Add default value to add_generation_prompt to preserve bc
1 parent 7ac16f9 commit 7650153

1 file changed: +4 −4 lines changed

torchchat/generate.py

Lines changed: 4 additions & 4 deletions
@@ -85,7 +85,7 @@ def __init__(self, tokenizer):
     def encode_dialog_prompt(
         self,
         dialog: DIALOG_TYPE,
-        add_generation_prompt: bool,
+        add_generation_prompt: bool = True,
     ) -> List[int]:
         """Encode a sequence of messages into a sequence of token IDs, including
         the chat template
@@ -136,7 +136,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
     def encode_dialog_prompt(
         self,
         dialog: _ChatFormatter.DIALOG_TYPE,
-        add_generation_prompt: bool,
+        add_generation_prompt: bool = True,
     ) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
@@ -166,7 +166,7 @@ def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str:
     def encode_dialog_prompt(
         self,
         dialog: _ChatFormatter.DIALOG_TYPE,
-        add_generation_prompt: bool,  # UNUSED
+        add_generation_prompt: bool = True,  # UNUSED
     ) -> List[int]:
         new_turn = True
         tokens = []
@@ -197,7 +197,7 @@ class HFTokenizerChatFormatter(_ChatFormatter):
     def encode_dialog_prompt(
         self,
         dialog: _ChatFormatter.DIALOG_TYPE,
-        add_generation_prompt: bool,
+        add_generation_prompt: bool = True,
     ) -> List[int]:
         rendered = self.tokenizer.apply_chat_template(
             dialog, add_generation_prompt=add_generation_prompt
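
Here "bc" is backward compatibility: each formatter's encode_dialog_prompt takes an add_generation_prompt flag, and without a default, every caller written before the flag existed would fail with a TypeError. A minimal sketch of the effect, using a hypothetical DemoFormatter (only the parameter and its new default come from this commit; the class, token values, and toy tokenizer logic are illustrative):

# Sketch only; not torchchat code.
from typing import Dict, List


class DemoFormatter:  # hypothetical stand-in for the four formatters
    def encode_dialog_prompt(
        self,
        dialog: List[Dict[str, str]],
        add_generation_prompt: bool = True,  # the default this commit adds
    ) -> List[int]:
        tokens = [1]  # pretend <|begin_of_text|> id
        for message in dialog:
            tokens.extend(ord(c) for c in message["content"])  # toy "tokenizer"
        if add_generation_prompt:
            tokens.append(2)  # pretend assistant-header id
        return tokens


formatter = DemoFormatter()
dialog = [{"role": "user", "content": "hi"}]

# An old-style call with no second argument would raise TypeError without
# the default; it now behaves like add_generation_prompt=True.
assert formatter.encode_dialog_prompt(dialog) == formatter.encode_dialog_prompt(
    dialog, add_generation_prompt=True
)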
