Commit 16088e6

chore: updates
2 parents: d16585f + c5d6e16

15 files changed: +323 additions, −118 deletions


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 20 additions & 7 deletions
@@ -189,10 +189,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     enabled_precisions = set(enabled_precisions)
@@ -308,6 +308,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
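The new check relies on the "val" entry that torch.export and Dynamo tracing normally attach to node.meta, which holds a FakeTensor carrying each node's shape and dtype; fake_tensor_prop re-executes the graph with fake tensors to populate that entry when it is missing. A minimal standalone sketch of that behavior, using an illustrative module that is not part of this commit:

import torch
from torch._inductor.compile_fx import fake_tensor_prop


class ToyModel(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(ToyModel())  # nodes start without "val" metadata
with torch.no_grad():
    fake_tensor_prop(gm, [torch.randn(2, 4)])  # propagates FakeTensors through the graph

for node in gm.graph.nodes:
    # After propagation each node should carry a FakeTensor (or tuple of them) under meta["val"].
    print(node.name, node.meta.get("val"))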
@@ -366,12 +384,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def _pretraced_backend(
         fake_mode, "allow_non_fake_inputs", True
     ), fake_mode:
         repair_input_aliasing(gm)
-
         # Invoke AOTAutograd to translate operators to aten
         gm = aot_export_joint_simple(
             gm,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
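For background (not part of the diff): aten.sym_size.int is the node torch.export emits whenever Python code reads a dimension of a tensor whose shape was marked dynamic, so this converter is what lets such graphs reach TensorRT. A small sketch showing where these nodes come from, with illustrative names:

import torch


class Flattener(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]  # traced as torch.ops.aten.sym_size.int(x, 0) when dim 0 is dynamic
        return x.reshape(batch, -1)


batch_dim = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    Flattener(),
    (torch.randn(4, 3, 8),),
    dynamic_shapes={"x": {0: batch_dim}},
)
print(ep.graph)  # typically contains a sym_size.int node feeding the reshape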

py/torch_tensorrt/dynamo/conversion/impl/grid.py

Lines changed: 2 additions & 4 deletions
@@ -1,13 +1,11 @@
-from typing import Optional, Sequence
+from typing import Optional

 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 # nearest, linear, cubic
 GridSamplerInterpolationMode = {

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 31 additions & 0 deletions
@@ -8,6 +8,10 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
@@ -19,6 +23,33 @@
 from torch_tensorrt.fx.types import TRTTensor


+def shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """
+    This is the general shape layer implementation in TensorRT.
+    sym_size.int ops map to addShape layer in TensorRT and returns
+    the dynamic shape of the tensor optionally taking in a dim argument.
+    """
+    shape_layer = ctx.net.add_shape(input_val)
+    input_shape = shape_layer.get_output(0)
+    set_layer_name(shape_layer, target, name + "_shape", source_ir)
+
+    n_dims = len(input_val.shape)
+    dim = get_positive_dim(dim, n_dims)
+    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
+    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    input_shape = gather_layer.get_output(0)
+
+    return input_shape
+
+
 def get_shape_with_dynamic_shape(
     ctx: ConversionContext,
     target: Target,
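The converter above emits an IShapeLayer followed by an IGatherLayer over the resulting shape tensor. A standalone sketch of the same two-layer pattern with the raw TensorRT Python API (illustrative only, not taken from the commit):

import numpy as np
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Input with a dynamic batch dimension (-1).
x = network.add_input("x", trt.float32, (-1, 3, 224, 224))

shape_layer = network.add_shape(x)  # 1-D shape tensor, e.g. [N, 3, 224, 224] at runtime
dim_index = network.add_constant((1,), np.array([0], dtype=np.int32))
gather_layer = network.add_gather(  # pick out dimension 0, i.e. the runtime batch size
    shape_layer.get_output(0), dim_index.get_output(0), axis=0
)
batch_size = gather_layer.get_output(0)  # 1-element int32 tensor holding N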

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 18 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
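A note on the pattern (mine, not the commit's): when every target dimension is a plain int the converter still sets reshape_dims, but once any dimension is a runtime tensor (for example the output of the shape converter above) it concatenates per-dimension int tensors into one 1-D shape tensor and feeds it to the shuffle layer as a second input, which is TensorRT's mechanism for data-dependent reshapes. Continuing the raw-API sketch from the shape.py section, under the same assumed builder, network, and gather_layer:

# Build a [N, -1] reshape target where N is only known at runtime.
dynamic_dim = gather_layer.get_output(0)  # runtime batch size from the previous sketch
minus_one = network.add_constant((1,), np.array([-1], dtype=np.int32))

concat_layer = network.add_concatenation([dynamic_dim, minus_one.get_output(0)])
concat_layer.axis = 0  # concatenate along the single shape axis -> a 2-element shape tensor

shuffle_layer = network.add_shuffle(x)
shuffle_layer.set_input(1, concat_layer.get_output(0))  # replaces a static reshape_dims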

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
+    """
+    Return a list with the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
+) -> None:
+    """
+    Set the metadata of all the target_op nodes in the graph to the given metadata list.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
Lines changed: 18 additions & 18 deletions
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    # Copy the orig_op's metadata to the replacement op
+    set_metadata(gm, replacement_op, metadata)
+
+    return gm
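Why the metadata round-trip is needed (my reading, consistent with the contains_metadata check added in _compiler.py): subgraph_rewriter.replace_pattern builds the replacement nodes from scratch, so whatever shape/dtype information lived in the original nodes' meta does not carry over on its own. A small illustrative sketch, with hypothetical names:

import torch
from torch.fx import subgraph_rewriter, symbolic_trace


class ViewModule(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.view.default(x, [2, -1])


def orig(input: torch.Tensor, shape):
    return torch.ops.aten.view.default(input, shape)


def replacement(input: torch.Tensor, shape):
    return torch.ops.aten.reshape.default(input, shape)


gm = symbolic_trace(ViewModule())
view_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.view.default)
view_node.meta["val"] = torch.empty(2, 2)  # stand-in for the FakeTensor metadata export provides

subgraph_rewriter.replace_pattern(gm, orig, replacement)

reshape_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.reshape.default)
print(reshape_node.meta)  # typically {} - the "val" entry does not survive the rewrite on its own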
Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
