Skip to content

Commit a1baf0d

Browse files
committed
Remove msg from assert_close
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent d066a25 commit a1baf0d

File tree

1 file changed

+4
-20
lines changed

1 file changed

+4
-20
lines changed

extension/llm/modules/test/test_attention.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,7 @@ def test_attention_eager(self):
9595
et_res = self.et_mha(self.x, self.x) # Self attention.
9696
tt_res = self.tt_mha(self.x, self.x) # Self attention.
9797

98-
assert_close(
99-
et_res,
100-
tt_res,
101-
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
102-
)
98+
assert_close(et_res, tt_res)
10399

104100
# test with kv cache
105101
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -130,11 +126,7 @@ def test_attention_eager(self):
130126
self.x, self.x, input_pos=next_input_pos
131127
) # Self attention with input pos.
132128

133-
assert_close(
134-
et_res,
135-
tt_res,
136-
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
137-
)
129+
assert_close(et_res, tt_res)
138130

139131
def test_attention_export(self):
140132
# Self attention.
@@ -147,11 +139,7 @@ def test_attention_export(self):
147139
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
148140
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
149141

150-
assert_close(
151-
et_res,
152-
tt_res,
153-
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
154-
)
142+
assert_close(et_res, tt_res)
155143

156144
# TODO: KV cache.
157145

@@ -177,10 +165,6 @@ def test_attention_executorch(self):
177165
et_res = method.execute((self.x, self.x, self.input_pos))
178166
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
179167

180-
assert_close(
181-
et_res[0],
182-
tt_res,
183-
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
184-
)
168+
assert_close(et_res[0], tt_res)
185169

186170
# TODO: KV cache.

0 commit comments

Comments
 (0)