
Commit 9b6f287

peri044 authored and apbose committed
chore: patch test_hw_compatibility and restrict flash attention test to sm >=86 (#2631)
1 parent 3e8d735 · commit 9b6f287

2 changed files with 18 additions and 2 deletions

noxfile.py

Lines changed: 3 additions & 2 deletions
@@ -258,11 +258,12 @@ def run_dynamo_runtime_tests(session):
     tests = [
         "runtime",
     ]
+    skip_tests = "-k not hw_compat"
     for test in tests:
         if USE_HOST_DEPS:
-            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
+            session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
         else:
-            session.run_always("pytest", test)
+            session.run_always("pytest", test, skip_tests)


 def run_dynamo_model_compile_tests(session):
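
Note: pytest's -k option filters tests by keyword expression, so "-k not hw_compat" deselects test_hw_compatibility (its name contains "hw_compat") while the rest of the runtime suite still runs. A minimal standalone sketch of the equivalent invocation, assuming pytest is installed and a "runtime" test directory exists; the subprocess form below is illustrative, not part of the commit:

import subprocess

# The whole filter travels as a single argv element, exactly as
# session.run_always("pytest", test, skip_tests) forwards it;
# pytest accepts the fused "-kEXPR" option form.
subprocess.run(
    ["pytest", "runtime", "-k not hw_compat"],
    check=False,  # pytest exits non-zero when any test fails
)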

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 15 additions & 0 deletions
@@ -3,10 +3,17 @@

 import torch
 import torch_tensorrt
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import TestCase, run_tests

 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing

+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
+    (8, 6),
+    (8, 7),
+    (8, 9),
+]
+

 class TestInputAsOutput(TestCase):
     def test_input_as_output(self):
@@ -279,6 +286,10 @@ def forward(self, q, k, v):
     "Test not supported on Windows",
 )
 class TestLowerFlashAttention(TestCase):
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_lower_flash_attention(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
@@ -348,6 +359,10 @@ def forward(self, q, k, v):
         )
         torch._dynamo.reset()

+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_flash_attention_converter(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
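
Note: torch.cuda.get_device_capability() returns the current device's compute capability as a (major, minor) tuple, which is what the membership test against (8, 6), (8, 7), and (8, 9) relies on. A standalone sketch of the same skip pattern, using a hypothetical test class that is not part of this commit:

import unittest

import torch

# Same guard as the patch: true only on SM 8.6 / 8.7 / 8.9 devices.
isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
    (8, 6),
    (8, 7),
    (8, 9),
]


class ExampleSM8XGatedTest(unittest.TestCase):  # hypothetical example class
    @unittest.skipIf(not isSM8XDevice, "requires SM 8.6/8.7/8.9 hardware")
    def test_only_runs_on_sm8x(self):
        self.assertIn(torch.cuda.get_device_capability(), [(8, 6), (8, 7), (8, 9)])


if __name__ == "__main__":
    unittest.main()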
