Skip to content

Commit 8d6c5aa

Browse files
authored
Merge branch 'main' into dynamo_refactor
2 parents 7cc5eb7 + e884820 commit 8d6c5aa

File tree

11 files changed

+313
-10
lines changed

11 files changed

+313
-10
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "torch/csrc/jit/passes/lower_graph.h"
1010
#include "torch/csrc/jit/passes/lower_tuples.h"
1111
#include "torch/csrc/jit/passes/peephole.h"
12-
#include "torch/csrc/jit/passes/remove_exceptions.h"
1312
#include "torch/csrc/jit/passes/remove_mutation.h"
1413

1514
#include "core/lowering/lowering.h"
@@ -105,7 +104,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
105104
torch::jit::InlineFunctionalGraphs(g);
106105
torch::jit::PeepholeOptimize(g, false);
107106
torch::jit::FuseLinear(g);
108-
torch::jit::EliminateExceptions(g);
107+
passes::EliminateExceptionsSafe(g);
109108
if (!lower_info.disable_cse) {
110109
torch::jit::EliminateCommonSubexpression(g);
111110
}

core/lowering/passes/exception_elimination.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "torch/csrc/jit/ir/alias_analysis.h"
22
#include "torch/csrc/jit/jit_log.h"
3+
#include "torch/csrc/jit/passes/constant_pooling.h"
34
#include "torch/csrc/jit/passes/constant_propagation.h"
45
#include "torch/csrc/jit/passes/dead_code_elimination.h"
56
#include "torch/csrc/jit/passes/guard_elimination.h"
@@ -108,6 +109,71 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108109
}
109110
}
110111

112+
/*
113+
Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
114+
using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
115+
so as to not invalidate the IR in challenging cases, such as nested Ifs
116+
117+
Original Source from which it was adapted:
118+
https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
119+
*/
120+
121+
bool certainlyThrows(Block* block) {
122+
// A block certainly throws an exception if it contains
123+
// the prim::RaiseException operation
124+
for (Node* n : block->nodes()) {
125+
if (n->kind() == prim::RaiseException) {
126+
return true;
127+
}
128+
}
129+
return false;
130+
}
131+
132+
void EliminateExceptionsSafe(Block* block) {
133+
auto graph = block->owningGraph();
134+
// Generate false and true constant placeholders
135+
Value* false_const = graph->insertConstant(IValue(false));
136+
Value* true_const = graph->insertConstant(IValue(true));
137+
138+
// For each prim::If node, if either block certainly throws an exception,
139+
// replace input conditional of the node input with the logical opposite
140+
for (Node* n : block->nodes()) {
141+
if (n->kind() == prim::If) {
142+
Block* true_block = n->blocks()[0];
143+
Block* false_block = n->blocks()[1];
144+
bool removed_exception = false;
145+
Value* input_value_replacement;
146+
147+
// If the block throws an exception, replace input with logical opposite
148+
if (certainlyThrows(true_block)) {
149+
removed_exception = true;
150+
input_value_replacement = false_const;
151+
} else if (certainlyThrows(false_block)) {
152+
removed_exception = true;
153+
input_value_replacement = true_const;
154+
}
155+
156+
// Log node and perform input replacement
157+
if (removed_exception) {
158+
LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
159+
n->insertInput(0, input_value_replacement);
160+
n->removeInput(1);
161+
}
162+
}
163+
164+
// Inspect and replace all instances within subblocks of the current node
165+
for (Block* subblock : n->blocks()) {
166+
EliminateExceptionsSafe(subblock);
167+
}
168+
}
169+
}
170+
171+
void EliminateExceptionsSafe(std::shared_ptr<Graph>& graph) {
172+
EliminateExceptionsSafe(graph->block());
173+
ConstantPropagation(graph);
174+
ConstantPooling(graph);
175+
}
176+
111177
} // namespace passes
112178
} // namespace lowering
113179
} // namespace core

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2020
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2121
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
2222
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
23+
void EliminateExceptionsSafe(std::shared_ptr<torch::jit::Graph>& graph);
2324
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
2425
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
2526
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _find_lib(name, paths):
9494

9595
from torch_tensorrt import fx
9696

97-
if version.parse(torch.__version__) >= version.parse("2.1.dev"):
97+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
9898
from torch_tensorrt import dynamo
9999
from torch_tensorrt.dynamo import backend
100100

py/torch_tensorrt/_util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ def get_build_info() -> str:
3030

3131
def set_device(gpu_id):
3232
_C.set_device(gpu_id)
33+
34+
35+
def sanitized_torch_version() -> str:
36+
return (
37+
torch.__version__
38+
if ".nv" not in torch.__version__
39+
else torch.__version__.split(".nv")[0]
40+
)

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,14 @@ def _pretraced_backend(
8888
)
8989
return gm.forward
9090
else:
91-
raise AssertionError(
91+
logger.critical(
9292
"Halting compilation on build failure since "
9393
+ "pass_through_build_failures was specified as True. "
9494
+ "To return the default Torch implementation and avoid "
9595
+ "halting compilation on engine build failures, "
9696
+ "specify pass_through_build_failures=False."
9797
)
98+
raise
9899

99100

100101
def _compile_module(

py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.testing._internal.common_utils import run_tests, TestCase
1212
from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
1313
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
14+
from torch_tensorrt._util import sanitized_torch_version
1415

1516
_LOGGER = logging.getLogger(__name__)
1617

@@ -43,7 +44,9 @@ def forward(self, x, y):
4344
%reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
4445
return reshape
4546
"""
46-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
47+
if version.parse(sanitized_torch_version()) < version.parse(
48+
"2.1.0.dev20230620"
49+
):
4750
expected_graph = expected_graph.replace("num_users", "#users")
4851

4952
assert (

py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch.nn as nn
99

1010
import torch_tensorrt.fx.passes.remove_duplicate_output_args as dedup
11+
from torch_tensorrt._util import sanitized_torch_version
12+
1113
from torch.testing._internal.common_utils import run_tests, TestCase
1214

1315
_LOGGER = logging.getLogger(__name__)
@@ -57,7 +59,9 @@ def is_leaf_module(self, m, qn):
5759
return add
5860
""".strip()
5961

60-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
62+
if version.parse(sanitized_torch_version()) < version.parse(
63+
"2.1.0.dev20230620"
64+
):
6165
ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
6266

6367
assert (
@@ -71,7 +75,9 @@ def is_leaf_module(self, m, qn):
7175
return (x,)
7276
""".strip()
7377

74-
if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
78+
if version.parse(sanitized_torch_version()) < version.parse(
79+
"2.1.0.dev20230620"
80+
):
7581
ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
7682

7783
assert (

py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from contextlib import contextmanager
44
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
55
from packaging import version
6+
from torch_tensorrt._util import sanitized_torch_version
67

78
import torch
89

9-
if version.parse(torch.__version__) >= version.parse("2.dev"):
10+
if version.parse(sanitized_torch_version()) >= version.parse("2.dev"):
1011
import torch._dynamo as torchdynamo
1112

1213
from torch.fx.passes.infra.pass_base import PassResult

py/torch_tensorrt/fx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
replace_op_with_indices,
1313
run_const_fold,
1414
)
15-
15+
from torch_tensorrt._util import sanitized_torch_version
1616
from .types import Shape, TRTDataType
1717

1818

@@ -160,7 +160,7 @@ def nested_decorator(f: Callable):
160160
def function_wrapper(*args, **kwargs):
161161
# Parse minimum and current Torch versions
162162
min_version = version.parse(min_torch_version)
163-
current_version = version.parse(torch.__version__)
163+
current_version = version.parse(sanitized_torch_version())
164164

165165
if current_version < min_version:
166166
raise AssertionError(

0 commit comments

Comments
 (0)