
Commit 489e22d

peri044 authored and apbose committed
chore: patch test_hw_compatibility and restrict flash attention test to sm >=86 (#2631)
1 parent bbe0af4 commit 489e22d

2 files changed: +18 -2 lines changed

noxfile.py

Lines changed: 3 additions & 2 deletions

@@ -258,11 +258,12 @@ def run_dynamo_runtime_tests(session):
     tests = [
         "runtime",
     ]
+    skip_tests = "-k not hw_compat"
     for test in tests:
         if USE_HOST_DEPS:
-            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
+            session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
         else:
-            session.run_always("pytest", test)
+            session.run_always("pytest", test, skip_tests)
 
 
 def run_dynamo_model_compile_tests(session):
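
Note: the new skip_tests string is handed to pytest as a single extra argument; pytest's CLI accepts the value fused to the -k short option, so it acts as the keyword filter "not hw_compat" and deselects any test whose id contains hw_compat (such as test_hw_compatibility) while the rest of the "runtime" suite still runs. A minimal sketch of how that filter behaves -- the file name and test names below are hypothetical, not part of this commit:

# test_hw_compat_sketch.py -- hypothetical tests illustrating pytest's -k deselection
import pytest


def test_hw_compatibility():
    # The test id contains "hw_compat", so `pytest -k "not hw_compat"` deselects it.
    pytest.skip("placeholder for the hardware-compatibility runtime test")


def test_runtime_smoke():
    # The test id does not match "hw_compat", so it still runs under the same filter.
    assert True

Running pytest -k "not hw_compat" test_hw_compat_sketch.py against this sketch collects both tests but deselects only the first one.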

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 15 additions & 0 deletions

@@ -2,10 +2,17 @@
 
 import torch
 import torch_tensorrt
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
+    (8, 6),
+    (8, 7),
+    (8, 9),
+]
+
 
 class TestInputAsOutput(TestCase):
     def test_input_as_output(self):
@@ -274,6 +281,10 @@ def forward(self, q, k, v):
     "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
 )
 class TestLowerFlashAttention(TestCase):
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_lower_flash_attention(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
@@ -343,6 +354,10 @@ def forward(self, q, k, v):
         )
         torch._dynamo.reset()
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_flash_attention_converter(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
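
Note: torch.cuda.get_device_capability() returns the (major, minor) compute capability of the current CUDA device, so isSM8XDevice is true only on SM 8.6, 8.7, or 8.9 hardware; the skipped tests additionally require PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_cuda. A standalone sketch of the same gate -- the test class and test name here are placeholders, not part of this commit:

# sm_gate_sketch.py -- standalone sketch of the SM >= 86 capability gate
import unittest

import torch

# get_device_capability() yields a (major, minor) tuple, e.g. (8, 6) on an
# RTX 3090.  The commit matches SM 8.6 / 8.7 / 8.9 explicitly rather than
# using a ">= (8, 6)" comparison.
isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
    (8, 6),
    (8, 7),
    (8, 9),
]


class SmGateExample(unittest.TestCase):
    @unittest.skipIf(not isSM8XDevice, "not SM86+ hardware")
    def test_runs_only_on_sm8x(self):
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()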
