1 parent 3fd9963 commit 4af2c24
extension/llm/modules/test/test_attention.py
@@ -167,9 +167,9 @@ def test_attention_aoti(self):
             path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
             mha_aoti = load_package(path)
 
-            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
             tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
-            self.assertTrue(torch.allclose(et_res, tt_res))
+            assert_close(aoti_res, tt_res)
 
     def test_attention_executorch(self):
         # Self attention.
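The substance of the change: the AOTI result variable is renamed from `et_res` to `aoti_res` to match what it actually holds, and the bare boolean check `self.assertTrue(torch.allclose(...))` is replaced with `assert_close`, which raises a descriptive error on mismatch instead of unittest's uninformative "False is not true". A minimal sketch of the two styles, assuming `assert_close` is imported from `torch.testing` as the unqualified call in the diff implies; the tensors and tolerances below are illustrative, not taken from the test:

```python
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0001])

# Old style: on failure, unittest reports only "False is not true",
# with no hint about where or by how much the tensors diverge.
# self.assertTrue(torch.allclose(a, b))

# New style: on failure, assert_close reports the number of mismatched
# elements and the greatest absolute/relative difference, and it also
# checks dtype and device agreement by default.
assert_close(a, b, rtol=1e-3, atol=1e-3)
```

Note that with `assert_close`'s default float32 tolerances this pair would fail with a detailed diff report, which is precisely the near-miss case that `torch.allclose` collapses into an opaque boolean.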