
Commit 3982401

apbose and peri044 authored
Torch TRT ngc container changes (#3299)
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent b4bc713 · commit 3982401

3 files changed: +28, -4 lines changed


core/util/Exception.cpp

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,13 @@
+#if defined(__GNUC__) && !defined(__clang__)
+#if __GNUC__ >= 13
+#include <cstdint>
+#endif
+#elif defined(__clang__)
+#if __clang_major__ >= 13
+#include <cstdint>
+#endif
+#endif
+
 #include "core/util/Exception.h"
 
 #include <iostream>
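Background on this change: starting with GCC 13, libstdc++ headers no longer transitively include <cstdint>, so translation units that use fixed-width integer types (uint8_t, int64_t, and so on) must include it explicitly. The preprocessor ladder keeps the include scoped to the compiler versions that need it, with a parallel branch for Clang 13 and later.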

noxfile.py

Lines changed: 3 additions & 4 deletions
@@ -258,11 +258,12 @@ def run_dynamo_runtime_tests(session):
     tests = [
         "runtime",
     ]
+    skip_tests = "-k not hw_compat"
     for test in tests:
         if USE_HOST_DEPS:
-            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
+            session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
         else:
-            session.run_always("pytest", test)
+            session.run_always("pytest", test, skip_tests)
 
 
 def run_dynamo_model_compile_tests(session):
@@ -332,7 +333,6 @@ def run_int8_accuracy_tests(session):
     tests = [
         "ptq/test_ptq_to_backend.py",
         "ptq/test_ptq_dataloader_calibrator.py",
-        "qat/",
     ]
     for test in tests:
         if USE_HOST_DEPS:
@@ -473,7 +473,6 @@ def run_l1_int8_accuracy_tests(session):
     install_deps(session)
     install_torch_trt(session)
     train_model(session)
-    finetune_model(session)
     run_int8_accuracy_tests(session)
     cleanup(session)
 
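The new skip_tests argument threads a pytest keyword filter through the session so that hardware-compatibility tests are deselected from the runtime suite. A minimal sketch of the same pattern, assuming nox and pytest are available (the session name and test path below are illustrative, not taken from the repo):

import nox


@nox.session
def runtime_tests(session):  # hypothetical session, not the repo's
    session.install("pytest")
    # Deselect every test whose id matches "hw_compat". The commit passes the
    # flag and expression as a single string ("-k not hw_compat"), which
    # pytest also accepts; splitting them as below is the conventional form.
    session.run("pytest", "runtime/", "-k", "not hw_compat")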

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 15 additions & 0 deletions
@@ -3,10 +3,17 @@
 
 import torch
 import torch_tensorrt
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
+    (8, 6),
+    (8, 7),
+    (8, 9),
+]
+
 
 class TestInputAsOutput(TestCase):
     def test_input_as_output(self):
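The module-level isSM8XDevice flag restricts the flash-attention tests to GPUs with compute capability 8.6, 8.7, or 8.9 (for example RTX 30-series, Jetson Orin, and RTX 40-series parts, respectively). A standalone sketch of the same probe:

import torch

# Report the compute capability torch sees; the isSM8XDevice check above
# compares this (major, minor) tuple against (8, 6), (8, 7), and (8, 9).
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"SM{major}{minor}:", (major, minor) in [(8, 6), (8, 7), (8, 9)])
else:
    print("no CUDA device visible")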
@@ -279,6 +286,10 @@ def forward(self, q, k, v):
     "Test not supported on Windows",
 )
 class TestLowerFlashAttention(TestCase):
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_lower_flash_attention(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
@@ -348,6 +359,10 @@ def forward(self, q, k, v):
         )
         torch._dynamo.reset()
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
+        "Does not support fused SDPA or not SM86+ hardware",
+    )
     def test_flash_attention_converter(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
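Both new decorators combine a software gate (PLATFORM_SUPPORTS_FLASH_ATTENTION, torch's internal probe for fused SDPA support) with the hardware gate defined at module scope. unittest.skipIf evaluates its condition once, when the module is imported, which is why isSM8XDevice is computed as a module-level constant rather than inside the tests. One detail worth noting: the skip message says "not SM86+ hardware", but the check matches only capabilities 8.6, 8.7, and 8.9, so devices outside that list, including newer ones, are skipped as well.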
