Commit 2399cfb

Revert "Torch TRT ngc container changes (#3299)"

This reverts commit 3982401 (parent: 3982401).

File tree

3 files changed: +4 −28 lines

core/util/Exception.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
-#if defined(__GNUC__) && !defined(__clang__)
-#if __GNUC__ >= 13
-#include <cstdint>
-#endif
-#elif defined(__clang__)
-#if __clang_major__ >= 13
-#include <cstdint>
-#endif
-#endif
 #include "core/util/Exception.h"
 
 #include <iostream>

noxfile.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,12 +258,11 @@ def run_dynamo_runtime_tests(session):
     tests = [
         "runtime",
     ]
-    skip_tests = "-k not hw_compat"
     for test in tests:
         if USE_HOST_DEPS:
-            session.run_always("pytest", test, skip_tests, env={"PYTHONPATH": PYT_PATH})
+            session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH})
         else:
-            session.run_always("pytest", test, skip_tests)
+            session.run_always("pytest", test)
 
 
 def run_dynamo_model_compile_tests(session):
@@ -333,6 +332,7 @@ def run_int8_accuracy_tests(session):
     tests = [
         "ptq/test_ptq_to_backend.py",
         "ptq/test_ptq_dataloader_calibrator.py",
+        "qat/",
     ]
     for test in tests:
         if USE_HOST_DEPS:
@@ -473,6 +473,7 @@ def run_l1_int8_accuracy_tests(session):
     install_deps(session)
     install_torch_trt(session)
     train_model(session)
+    finetune_model(session)
     run_int8_accuracy_tests(session)
     cleanup(session)

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,10 @@
 
 import torch
 import torch_tensorrt
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import TestCase, run_tests
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
-isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [
-    (8, 6),
-    (8, 7),
-    (8, 9),
-]
-
 
 class TestInputAsOutput(TestCase):
     def test_input_as_output(self):
@@ -286,10 +279,6 @@ def forward(self, q, k, v):
     "Test not supported on Windows",
 )
 class TestLowerFlashAttention(TestCase):
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
-        "Does not support fused SDPA or not SM86+ hardware",
-    )
     def test_lower_flash_attention(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):
@@ -359,10 +348,6 @@ def forward(self, q, k, v):
         )
         torch._dynamo.reset()
 
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
-        "Does not support fused SDPA or not SM86+ hardware",
-    )
     def test_flash_attention_converter(self):
         class FlashAttention(torch.nn.Module):
             def forward(self, q, k, v):

0 commit comments

Comments
 (0)