@@ -2,10 +2,17 @@
 
 import torch
 import torch_tensorrt
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
+    (8, 6),
+    (8, 7),
+    (8, 9),
+]
+
 
 class TestInputAsOutput(TestCase):
     def test_input_as_output(self):
@@ -274,6 +281,10 @@ def forward(self, q, k, v):
     "GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
 )
 class TestLowerFlashAttention(TestCase):
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_lower_flash_attention(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
@@ -343,6 +354,10 @@ def forward(self, q, k, v):
         )
         torch._dynamo.reset()
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_flash_attention_converter(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):