Skip to content

Commit bf8af5a

Browse files
authored
Merge pull request #976 from NVIDIA/if_loop_support
If loop support
2 parents 7e404e6 + 6c83a50 commit bf8af5a

File tree

7 files changed

+22
-14
lines changed

7 files changed

+22
-14
lines changed

core/conversion/conversion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ bool OpSupported(const torch::jit::Node* n) {
2323
return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n);
2424
}
2525

26+
bool SpecialCaseSupport(const torch::jit::Node* n) {
27+
return n->kind() == torch::jit::prim::Loop || n->kind() == torch::jit::prim::If;
28+
}
29+
2630
c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level, int limit) {
2731
// Check to see if you can just go through and eval all of these AOT (saves
2832
// the recursion) Also probably a better way to deal with the two error cases;
@@ -499,7 +503,7 @@ std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(cons
499503
auto schema = n->maybeSchema();
500504
// Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema
501505
// but they are supported. torch::jit::prim::DictConstruct is supported via fallback only
502-
if (!OpSupported(n)) {
506+
if (!OpSupported(n) && !SpecialCaseSupport(n)) {
503507
if (schema) {
504508
std::stringstream ss;
505509
ss << *schema;

core/conversion/evaluators/eval_util.cpp

100755100644
File mode changed.

core/conversion/evaluators/eval_util.h

100755100644
File mode changed.

core/conversion/evaluators/prim.cpp

100755100644
File mode changed.

noxfile.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
# TOP_DIR
1010
TOP_DIR=os.path.dirname(os.path.realpath(__file__)) if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]
1111

12-
nox.options.sessions = ["l0_api_tests-3"]
12+
SUPPORTED_PYTHON_VERSIONS=["3.7", "3.8", "3.9", "3.10"]
13+
14+
nox.options.sessions = ["l0_api_tests-3.7"]
1315

1416
def install_deps(session):
1517
print("Installing deps")
@@ -268,62 +270,62 @@ def run_l2_multi_gpu_tests(session, use_host_env=False):
268270
run_multi_gpu_tests(session, use_host_env)
269271
cleanup(session)
270272

271-
@nox.session(python=["3"], reuse_venv=True)
273+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
272274
def l0_api_tests(session):
273275
"""When a developer needs to check correctness for a PR or something"""
274276
run_l0_api_tests(session, use_host_env=False)
275277

276-
@nox.session(python=["3"], reuse_venv=True)
278+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
277279
def l0_api_tests_host_deps(session):
278280
"""When a developer needs to check basic api functionality using host dependencies"""
279281
run_l0_api_tests(session, use_host_env=True)
280282

281-
@nox.session(python=["3"], reuse_venv=True)
283+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
282284
def l0_dla_tests_host_deps(session):
283285
"""When a developer needs to check basic api functionality using host dependencies"""
284286
run_l0_dla_tests(session, use_host_env=True)
285287

286-
@nox.session(python=["3"], reuse_venv=True)
288+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
287289
def l1_accuracy_tests(session):
288290
"""Checking accuracy performance on various usecases"""
289291
run_l1_accuracy_tests(session, use_host_env=False)
290292

291-
@nox.session(python=["3"], reuse_venv=True)
293+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
292294
def l1_accuracy_tests_host_deps(session):
293295
"""Checking accuracy performance on various usecases using host dependencies"""
294296
run_l1_accuracy_tests(session, use_host_env=True)
295297

296-
@nox.session(python=["3"], reuse_venv=True)
298+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
297299
def l1_int8_accuracy_tests(session):
298300
"""Checking accuracy performance on various usecases"""
299301
run_l1_int8_accuracy_tests(session, use_host_env=False)
300302

301-
@nox.session(python=["3"], reuse_venv=True)
303+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
302304
def l1_int8_accuracy_tests_host_deps(session):
303305
"""Checking accuracy performance on various usecases using host dependencies"""
304306
run_l1_int8_accuracy_tests(session, use_host_env=True)
305307

306-
@nox.session(python=["3"], reuse_venv=True)
308+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
307309
def l2_trt_compatibility_tests(session):
308310
"""Makes sure that TensorRT Python and Torch-TensorRT can work together"""
309311
run_l2_trt_compatibility_tests(session, use_host_env=False)
310312

311-
@nox.session(python=["3"], reuse_venv=True)
313+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
312314
def l2_trt_compatibility_tests_host_deps(session):
313315
"""Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
314316
run_l2_trt_compatibility_tests(session, use_host_env=True)
315317

316-
@nox.session(python=["3"], reuse_venv=True)
318+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
317319
def l2_multi_gpu_tests(session):
318320
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
319321
run_l2_multi_gpu_tests(session, use_host_env=False)
320322

321-
@nox.session(python=["3"], reuse_venv=True)
323+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
322324
def l2_multi_gpu_tests_host_deps(session):
323325
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
324326
run_l2_multi_gpu_tests(session, use_host_env=True)
325327

326-
@nox.session(python=["3"], reuse_venv=True)
328+
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
327329
def download_test_models(session):
328330
"""Grab all the models needed for testing"""
329331
download_models(session, use_host_env=True)

py/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
-f https://download.pytorch.org/whl/torch_stable.html
2+
-f https://download.pytorch.org/whl/torch/
3+
--extra-index-url https://download.pytorch.org/whl/cu113
24
torch==1.11.0+cu113
35
pybind11==2.6.2

tests/core/lowering/test_reduce_to_pass.cpp

100755100644
File mode changed.

0 commit comments

Comments
 (0)