
Commit c5d6e16

feat: Implement symbolic shape propagation, sym_size converter (#2473)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 69795be commit c5d6e16

File tree

17 files changed: +331 -126 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 20 additions & 7 deletions
@@ -189,10 +189,10 @@ def compile(
     )
     gm = exported_program.module()
     logger.debug("Input graph: " + str(gm.graph))
-
     # Apply lowering on the graph module
     torch_inputs = get_torch_inputs(inputs, device)
     gm = apply_lowering_passes(gm, torch_inputs)
+
     logger.debug("Lowered Input graph: " + str(gm.graph))

     enabled_precisions = set(enabled_precisions)
@@ -308,6 +308,24 @@ def compile_module(
         f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
     )

+    def contains_metadata(gm: torch.fx.GraphModule) -> bool:
+        for node in gm.graph.nodes:
+            if node.op != "output" and (not node.meta) and "val" not in node.meta:
+                logger.warning(
+                    f"Node {node.name} of op type {node.op} does not have metadata. This could sometimes lead to undefined behavior."
+                )
+                return False
+        return True
+
+    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    if not contains_metadata(gm):
+        from torch._inductor.compile_fx import fake_tensor_prop
+
+        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
+        with torch.no_grad():
+            # This fails if the module has data-dependent shape operators.
+            fake_tensor_prop(gm, torch_inputs)
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False

@@ -366,12 +384,7 @@ def compile_module(
         )

         # Get the submodule inputs for min, opt, max shapes of the graph inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module,
-            submodule,
-            sample_inputs,
-            to_torch_device(settings.device),
-        )
+        submodule_inputs = partitioning.construct_submodule_inputs(submodule)

         logger.debug(
             "Submodule name: %s\n Input shapes: %s\n %s",

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 1 deletion
@@ -74,7 +74,6 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
-
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions
@@ -392,6 +392,22 @@ def aten_ops_sigmoid(
     )


+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int)
+def aten_ops_symsize_int(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
+
+
 def index_dtype_validator(node: Node) -> bool:
     index = node.args[1]
     for ind in index:
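
As background for the new converter registration: aten.sym_size.int nodes appear when a model is captured with dynamic dimensions and its code reads one of those dimensions. A rough sketch with a hypothetical module (module name, shapes, and dimension bounds are illustrative, not taken from the commit):

    import torch


    class ToyModel(torch.nn.Module):  # hypothetical module
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            batch = x.shape[0]  # recorded as torch.ops.aten.sym_size.int(x, 0) when dim 0 is dynamic
            return x.reshape(batch, -1)


    batch_dim = torch.export.Dim("batch", min=1, max=32)
    ep = torch.export.export(
        ToyModel(),
        (torch.randn(4, 3, 8),),
        dynamic_shapes={"x": {0: batch_dim}},
    )
    print(ep.graph)  # the printed graph typically contains a sym_size.int call, which this converter now handles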

py/torch_tensorrt/dynamo/conversion/impl/grid.py

Lines changed: 2 additions & 4 deletions
@@ -1,13 +1,11 @@
-from typing import Optional, Sequence
+from typing import Optional

 import tensorrt as trt
-import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor

 # nearest, linear, cubic
 GridSamplerInterpolationMode = {

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 6 additions & 6 deletions
@@ -90,7 +90,7 @@ def index(
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
-        f"Determining whether aten.index constant-index optimization can be invoked"
+        "Determining whether aten.index constant-index optimization can be invoked"
     )
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
         return identity_layer.get_output(0)
     elif len(tensor_indices) == 1:
         indices_tensor = get_trt_tensor(
-            ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor"
+            ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor"
         )
         index = adv_indx_indices[0]
         _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}")
@@ -204,7 +204,7 @@ def index(
                 cum_adv_index = cum_adv_index + adv_index
                 multiplier = multiplier * input_shape[adv_indx_indices[i]]
             cum_adv_index = get_trt_tensor(
-                ctx, cum_adv_index, name + f"_index_sum_intermediate"
+                ctx, cum_adv_index, name + "_index_sum_intermediate"
             )
         else:
             multiplier = get_trt_tensor(
@@ -263,7 +263,7 @@ def index(
             adv_indx_count
             == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
         ):
-            _LOGGER.debug(f"The indices are continuous in this case")
+            _LOGGER.debug("The indices are continuous in this case")
             concat_tensor_reshape.append(
                 get_trt_tensor(ctx, -1, name + "_dynamic_concat")
             )
@@ -287,7 +287,7 @@ def index(
                 source_ir,
             )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
-            _LOGGER.debug(f"The tensor is unfolded now")
+            _LOGGER.debug("The tensor is unfolded now")
            _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}")

             # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
             reshape_output = unfold_advanced_shuffle_layer.get_output(0)

         else:
-            _LOGGER.debug(f"The indices are not continuous in this case")
+            _LOGGER.debug("The indices are not continuous in this case")
             concat_final_tensor = []
             concat_final_tensor.append(cum_adv_index_shape_tensor)
             for i in range(0, rank):

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 32 additions & 1 deletion
@@ -8,14 +8,45 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+    to_numpy,
+)
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor


+def shape(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    """
+    This is the general shape layer implementation in TensorRT.
+    sym_size.int ops map to addShape layer in TensorRT and returns
+    the dynamic shape of the tensor optionally taking in a dim argument.
+    """
+    shape_layer = ctx.net.add_shape(input_val)
+    input_shape = shape_layer.get_output(0)
+    set_layer_name(shape_layer, target, name + "_shape", source_ir)
+
+    n_dims = len(input_val.shape)
+    dim = get_positive_dim(dim, n_dims)
+    dim_tensor = get_trt_tensor(ctx, dim, name + "_dim")
+    gather_layer = ctx.net.add_gather(input_shape, dim_tensor, axis=0)
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
+    input_shape = gather_layer.get_output(0)
+
+    return input_shape
+
+
 def get_shape_with_dynamic_shape(
     ctx: ConversionContext,
     target: Target,
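
The shape() helper above reduces to two TensorRT layers: IShapeLayer to materialize the runtime shape, then a gather along axis 0 to select the requested dimension. A standalone sketch of that layer sequence with the raw TensorRT Python API (input name and dimensions are hypothetical; the dtype of the shape tensor varies by TensorRT version):

    import numpy as np
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # Input with a dynamic batch dimension.
    x = network.add_input("x", trt.float32, (-1, 3, 224, 224))

    # IShapeLayer: a 1-D tensor holding the runtime shape of x.
    shape_layer = network.add_shape(x)

    # Gather element 0 of the shape tensor -> the dynamic batch size.
    dim_index = network.add_constant((1,), np.array([0], dtype=np.int32))
    gather_layer = network.add_gather(
        shape_layer.get_output(0), dim_index.get_output(0), axis=0
    )
    # gather_layer.get_output(0) is the runtime size of dim 0, ready to feed into
    # downstream layers (e.g. as part of a shuffle layer's shape input).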

py/torch_tensorrt/dynamo/conversion/impl/shuffle.py

Lines changed: 18 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor

@@ -17,7 +17,23 @@ def reshape(
     shape: Sequence[int],
 ) -> TRTTensor:
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(shape)
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to trt Tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(ctx, s, f"{name}_{i}")
+                trt_shape.append(a)
+        shape_layer = ctx.net.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

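
The dynamic branch added to reshape() follows a standard TensorRT pattern: build the target shape as a tensor (concatenating per-dimension shape tensors and constants) and feed it into the shuffle layer's second input instead of assigning static reshape_dims. A self-contained sketch with the raw TensorRT Python API (names and dimensions are hypothetical; component dtypes may need to match int32/int64 depending on the TensorRT version):

    import numpy as np
    import tensorrt as trt

    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    x = network.add_input("x", trt.float32, (-1, 3, 8))

    # Runtime batch size: shape(x) gathered at index 0.
    shape_layer = network.add_shape(x)
    idx = network.add_constant((1,), np.array([0], dtype=np.int32))
    batch = network.add_gather(shape_layer.get_output(0), idx.get_output(0), axis=0)

    # Static trailing dimension (3 * 8 = 24) as a 1-D constant.
    rest = network.add_constant((1,), np.array([24], dtype=np.int32))

    # Concatenate [batch, 24] into a shape tensor and wire it to the shuffle layer.
    target_shape = network.add_concatenation([batch.get_output(0), rest.get_output(0)])
    target_shape.axis = 0

    shuffle = network.add_shuffle(x)
    shuffle.set_input(1, target_shape.get_output(0))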

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,6 @@ def expand(
 ) -> TRTTensor:
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
-
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
@@ -99,6 +98,7 @@
     stride = tuple(
         [int(i == o) for i, o in zip(input_tensor_shape, shape)]
     )  # stride == 1 if dimensions match, 0 otherwise
+
     layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def upsample(
         resize_layer.scales = [1.0, 1.0] + list(scale_factors)
     else:
         raise RuntimeError(
-            f"At least one of out_shape and scale_factors should be specified."
+            "At least one of out_shape and scale_factors should be specified."
         )

     # interpolate mode

py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 import torch

@@ -29,3 +29,24 @@ def get_tensor_placeholders(
     ]

     return placeholders
+
+
+def get_metadata(
+    gm: torch.fx.GraphModule, target_op: Any
+) -> List[torch._ops.OpOverload]:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    return [node.meta for node in gm.graph.nodes if node.target == target_op]
+
+
+def set_metadata(
+    gm: torch.fx.GraphModule, target_op: Any, metadata: List[torch._ops.OpOverload]
+) -> None:
+    """
+    Return the list which has the metadata of all the target_op nodes present in the graph.
+    """
+    target_nodes = [node for node in gm.graph.nodes if node.target == target_op]
+    assert len(target_nodes) == len(metadata)
+    for idx, node in enumerate(target_nodes):
+        node.meta = metadata[idx]
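
A small usage sketch of the two new helpers on a hypothetical exported module (module and shapes are illustrative): stash node.meta for every aten.view.default node before a graph rewrite, then re-attach it to the replacement op afterwards, which is how the view_to_reshape pass uses them in the next diff.

    import torch
    from torch_tensorrt.dynamo.lowering.passes.pass_utils import get_metadata, set_metadata


    class ViewModel(torch.nn.Module):  # hypothetical toy module containing aten.view
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.view(-1, 4) + 1.0


    gm = torch.export.export(ViewModel(), (torch.randn(2, 2, 4),)).module()

    # Collect node.meta (including the "val" fake-tensor entries) of all view nodes.
    saved = get_metadata(gm, torch.ops.aten.view.default)
    print(len(saved), "view node(s) found")

    # After a pass rewrites view into reshape, the saved metadata can be restored
    # onto the new nodes with:
    #   set_metadata(gm, torch.ops.aten.reshape.default, saved)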
py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 18 additions & 18 deletions

@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, List, Sequence, Tuple
+from typing import List, Sequence

 import torch
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
+    get_metadata,
+    set_metadata,
 )

 logger = logging.getLogger(__name__)
@@ -13,27 +15,25 @@ def view_to_reshape(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
     """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
-    orig, replacement = view_replacement()
-
-    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
-        gm = clean_up_graph_after_modifications(gm)
-        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
-
-    return gm
-
-
-def view_replacement() -> Tuple[
-    torch.fx.GraphModule,
-    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-]:
-    """Constructs the original and replacement functions for view"""
+    orig_op = torch.ops.aten.view.default
+    replacement_op = torch.ops.aten.reshape.default

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.view.default(input, shape)
+        return orig_op(input, shape)

     # Replacement graph
     def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
-        return torch.ops.aten.reshape.default(input, shape)
+        return replacement_op(input, shape)

-    return orig, replacement
+    # Store metadata of the orig_op
+    metadata = get_metadata(gm, orig_op)
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    # Copy the orig_op's metadata to the replacement op
+    set_metadata(gm, replacement_op, metadata)
+
+    return gm
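
For reference, torch.fx.subgraph_rewriter.replace_pattern is the rewrite primitive used above; the nodes it inserts typically start without the matched nodes' "val" metadata, which is why the pass now saves it up front and restores it after the rewrite. A minimal self-contained sketch (toy module and pattern are hypothetical):

    import torch
    from torch.fx import subgraph_rewriter, symbolic_trace


    class Doubler(torch.nn.Module):  # hypothetical toy module
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.add(x, x)


    def pattern(x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x)


    def replacement(x: torch.Tensor) -> torch.Tensor:
        return torch.mul(x, 2)


    gm = symbolic_trace(Doubler())
    matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
    print(len(matches), "pattern(s) replaced")
    print(gm.code)  # forward now calls torch.mul; the replacement node starts with fresh meta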
py/torch_tensorrt/dynamo/partitioning/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -1,3 +1,7 @@
 from ._adjacency_partitioner import partition as fast_partition
 from ._global_partitioner import partition as global_partition
-from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis
+from .common import (
+    construct_submodule_inputs,
+    get_graph_converter_support,
+    run_shape_analysis,
+)
