Commit 0390655

test: Add initial unit tests for chat formatters
There's no formal execution framework for pytest yet, but these were helpful in ensuring that the formatting was working correctly! To run them, install pytest and run `pytest tests/`.

Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 8d26923 commit 0390655
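
For reference, the tests can also be launched programmatically. Below is a minimal sketch (not part of the commit), assuming pytest is installed and the working directory is the repo root; `pytest.main` accepts the same arguments as the CLI and returns an exit code:

```python
# Programmatic equivalent of `pytest tests/` (a sketch, not part of the commit)
import sys

import pytest

if __name__ == "__main__":
    # pytest.main mirrors the CLI invocation and returns its exit code
    sys.exit(pytest.main(["tests/"]))
```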

File tree

2 files changed: 239 additions & 0 deletions


tests/conftest.py

Lines changed: 12 additions & 0 deletions
```python
"""
Global pytest config, fixtures, and helpers go here!
"""

# Standard
import os
import sys

# Make sure tests can import torchchat
sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
)
```
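
As a quick sanity check of this path setup, a hypothetical smoke test (not in the commit) could assert that the package imports; on pytest 7+ the same effect could also be achieved with the `pythonpath` ini option instead of editing `sys.path` in `conftest.py`.

```python
# Hypothetical smoke test (not in this commit): with the repo root placed on
# sys.path by conftest.py, importing the torchchat package should succeed.
def test_torchchat_importable():
    import torchchat  # noqa: F401
```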

tests/test_chat_formatters.py

Lines changed: 227 additions & 0 deletions
```python
"""
Unit tests for chat formatters
"""

# Third Party
import pytest

# Local
from torchchat.generate import (
    HFTokenizerChatFormatter,
    Llama2ChatFormatter,
    Llama3ChatFormatter,
)

## Helpers #####################################################################

class DummyTokenizer:
    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
    def encode(self, text, *_, **__):
        return text


class DummySPTokenizer(DummyTokenizer):
    """Emulated Sentencepiece tokenizer with bos/eos"""
    bos = "<s>"
    eos = "</s>"


class DummyLlama3Tokenizer(DummyTokenizer):
    class _IdentityDict:
        def __getitem__(self, key):
            return key
    special_tokens = _IdentityDict()


class DummyHFTokenizer(DummyTokenizer):
    """Dummy made up chat template scheme"""
    # Sequence
    bos = "<bos>"
    # Turn
    bot = "<bot>"
    eot = "<eot>"
    # Role
    bor = "<bor>"
    eor = "<eor>"
    def apply_chat_template(self, messages, add_generation_prompt):
        out = [self.bos]
        role = None
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
        if add_generation_prompt and role != "assistant":
            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
        return "\n".join(out)


def check_rendering(fmt, messages, expected, add_generation_prompt):
    """Render messages and compare to expected output"""
    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected


def make_message(role, text):
    return {"role": role, "content": text}


SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
USER1 = "Hello world!"
ASSISTANT1 = "Greetings! How can I help you?"
USER2 = "Why is the sky blue?"
ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."


# Stock sets of messages to test
MSGS_NO_SYS = [
    make_message("user", USER1),
]
MSGS_SYS_USR = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
]
MSGS_SYS_USR_ASST = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
]
MSGS_MULTI_TURN = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
    make_message("user", USER2),
    make_message("assistant", ASSISTANT2),
]

## Llama2ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
        # sys, usr
        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST]"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
"""),
    ]
)
def test_llama2_chat_formatter(messages, expected):
    """Tests for Llama2 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
    """
    tok = DummySPTokenizer()
    fmt = Llama2ChatFormatter(tok)
    # NOTE: add_generation_prompt not used by Llama2
    check_rendering(fmt, messages, expected, True)

## Llama3ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>
"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>
"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|>
"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{USER2}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

{ASSISTANT2}<|eot_id|>
"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
    """Tests for Llama3 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
    """
    tok = DummyLlama3Tokenizer()
    fmt = Llama3ChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    check_rendering(fmt, messages, expected, add_generation_prompt)

## HFTokenizerChatFormatter ####################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<bos>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>
<bot><bor>user<eor>{USER2}<eot>
<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_hf_chat_formatter(messages, expected, add_generation_prompt):
    tok = DummyHFTokenizer()
    fmt = HFTokenizerChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
    check_rendering(fmt, messages, expected, add_generation_prompt)
```
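
To see what the dummy HF template produces, here is a small worked example (not part of the commit; the output is derived by hand from `apply_chat_template` above):

```python
tok = DummyHFTokenizer()
print(tok.apply_chat_template(
    [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    add_generation_prompt=True,
))
# Prints (no trailing assistant prompt, since the last turn is the assistant's):
# <bos>
# <bot><bor>user<eor>Hi<eot>
# <bot><bor>assistant<eor>Hello!<eot>
```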
