Skip to content

Commit 47b2902

Browse files
authored
chore: enable DS support for converters (#2775)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 722457b commit 47b2902

29 files changed

+367
-156
lines changed

.github/workflows/build-test-windows.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jobs:
7272
export USE_HOST_DEPS=1
7373
pushd .
7474
cd tests/py/dynamo
75-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
75+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
7676
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
7777
popd
7878
@@ -98,7 +98,7 @@ jobs:
9898
export USE_HOST_DEPS=1
9999
pushd .
100100
cd tests/py/dynamo
101-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
101+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
102102
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
103103
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
104104
popd
@@ -125,7 +125,7 @@ jobs:
125125
export USE_HOST_DEPS=1
126126
pushd .
127127
cd tests/py/dynamo
128-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
128+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
129129
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
130130
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
131131
popd
@@ -152,7 +152,7 @@ jobs:
152152
export USE_HOST_DEPS=1
153153
pushd .
154154
cd tests/py/dynamo
155-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
155+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
156156
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
157157
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
158158
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/

.github/workflows/build-test.yml

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,15 @@ jobs:
7878
script: |
7979
export USE_HOST_DEPS=1
8080
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
81-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
8281
pushd .
8382
cd tests/modules
8483
# Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
85-
${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2
84+
${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers==4.39.3 timm==0.9.16 pybind11==2.6.2
8685
${CONDA_RUN} python hub.py
8786
popd
8887
pushd .
8988
cd tests/py/ts
90-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
89+
${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
9190
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_api_test_results.xml api/
9291
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_models_test_results.xml models/
9392
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/ts_integrations_test_results.xml integrations/
@@ -115,10 +114,9 @@ jobs:
115114
pre-script: ${{ matrix.pre-script }}
116115
script: |
117116
export USE_HOST_DEPS=1
118-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
119117
pushd .
120118
cd tests/py/dynamo
121-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
119+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
122120
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
123121
popd
124122
@@ -144,10 +142,9 @@ jobs:
144142
pre-script: ${{ matrix.pre-script }}
145143
script: |
146144
export USE_HOST_DEPS=1
147-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
148145
pushd .
149146
cd tests/py/dynamo
150-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
147+
${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
151148
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
152149
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
153150
popd
@@ -174,10 +171,9 @@ jobs:
174171
pre-script: ${{ matrix.pre-script }}
175172
script: |
176173
export USE_HOST_DEPS=1
177-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
178174
pushd .
179175
cd tests/py/dynamo
180-
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
176+
${CONDA_RUN} python -m pip install --pre pytest timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
181177
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
182178
popd
183179
@@ -203,10 +199,9 @@ jobs:
203199
pre-script: ${{ matrix.pre-script }}
204200
script: |
205201
export USE_HOST_DEPS=1
206-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
207202
pushd .
208203
cd tests/py/dynamo
209-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
204+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
210205
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
211206
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
212207
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
@@ -234,10 +229,9 @@ jobs:
234229
pre-script: ${{ matrix.pre-script }}
235230
script: |
236231
export USE_HOST_DEPS=1
237-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
238232
pushd .
239233
cd tests/py/dynamo
240-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
234+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
241235
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml runtime/
242236
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_partitioning_test_results.xml partitioning/
243237
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_lowering_test_results.xml lowering/
@@ -264,9 +258,8 @@ jobs:
264258
pre-script: ${{ matrix.pre-script }}
265259
script: |
266260
export USE_HOST_DEPS=1
267-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.1.6/lib:$LD_LIBRARY_PATH
268261
pushd .
269262
cd tests/py/core
270-
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
263+
${CONDA_RUN} python -m pip install --pre pytest-xdist timm==0.9.16 transformers==4.39.3 parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
271264
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
272265
popd

core/runtime/execute_engine.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
124124
}
125125
}
126126

127+
// Buffer holding the int32 values of shape tensor inputs; it lives for the whole runtime scope so the addresses passed to setTensorAddress stay valid
128+
std::list<std::vector<int32_t>> inputShapeTensorValues;
127129
{
128130
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
129131
if (compiled_engine->profile_execution) {
@@ -142,12 +144,30 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
142144
auto dims = core::util::toDims(inputs[i].sizes());
143145
auto shape = core::util::toVec(dims);
144146
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
145-
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
146-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
147+
if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
148+
// Shape tensor inputs are cast to int32 explicitly.
149+
// Refer to
150+
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
151+
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt32);
152+
std::vector<int32_t> inputs_cpu_vec(
153+
input_cpu.data_ptr<int32_t>(), input_cpu.data_ptr<int32_t>() + input_cpu.numel());
154+
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
155+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data());
156+
} else {
157+
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
158+
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputs[i].view(shape).contiguous().data_ptr());
159+
}
147160
}
148161

162+
// Check if input shapes can be inferred.
163+
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
164+
std::vector<char const*> names(io_size);
165+
int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
149166
TORCHTRT_CHECK(
150-
compiled_engine->exec_ctx->allInputShapesSpecified(), "Not enough inputs provided (runtime.RunCudaEngine)");
167+
nbNames == 0,
168+
"The shapes of the inputs: "
169+
<< names
170+
<< " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
151171
}
152172

153173
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

py/torch_tensorrt/_Input.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class _ShapeMode(Enum):
4747
high_tensor_domain_excl: float = low_tensor_domain_incl + DOMAIN_OFFSET
4848
torch_tensor: torch.Tensor = None
4949
name: str = ""
50+
is_shape_tensor: bool = False
5051

5152
def __init__(self, *args: Any, **kwargs: Any) -> None:
5253
"""__init__ Method for torch_tensorrt.Input
@@ -161,6 +162,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
161162
else:
162163
self._explicit_set_dtype = False
163164

165+
if "is_shape_tensor" in kwargs:
166+
self.is_shape_tensor = kwargs["is_shape_tensor"]
167+
164168
if "format" in kwargs:
165169
self.format = memory_format._from(kwargs["format"])
166170

@@ -174,7 +178,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
174178
if "torch_tensor" in kwargs:
175179
self.torch_tensor = kwargs["torch_tensor"]
176180
else:
177-
if self.shape_mode == Input._ShapeMode.DYNAMIC:
181+
if self.is_shape_tensor:
182+
self.torch_tensor = torch.tensor(
183+
kwargs["opt_shape"], dtype=kwargs["dtype"]
184+
)
185+
elif self.shape_mode == Input._ShapeMode.DYNAMIC:
178186
self.torch_tensor = self.example_tensor("opt_shape")
179187
else:
180188
self.torch_tensor = self.example_tensor()

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,9 @@ def trace(
5858

5959
device = to_torch_device(kwargs.get("device", default_device()))
6060
torch_inputs = get_torch_inputs(inputs, device)
61-
dynamic_shapes = {}
61+
dynamic_shapes = []
6262
for input in inputs:
6363
if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
64-
if not input.name:
65-
raise AssertionError(
66-
f"Expected a name for a dynamic input with shape {input.shape} but found none"
67-
)
6864
min_shape = input.shape["min_shape"]
6965
opt_shape = input.shape["opt_shape"]
7066
max_shape = input.shape["max_shape"]
@@ -80,8 +76,8 @@ def trace(
8076
max=max_shape[dim],
8177
)
8278

83-
dynamic_shapes[input.name] = dynamic_dims
79+
dynamic_shapes.append(dynamic_dims)
8480

85-
exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=dynamic_shapes)
81+
exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
8682

8783
return exp_program

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def _pretraced_backend(
9696

9797
gm = apply_lowering_passes(gm, torch_inputs)
9898

99+
logger.debug("Lowered Input graph:\n " + str(gm.graph))
100+
99101
torchtrt_inputs = prepare_inputs(
100102
torch_inputs, disable_memory_format_check=True
101103
)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
55

66
import numpy as np
7+
import tensorrt as trt
78
import torch
89
import torch.fx
910
from torch.fx.node import _get_qualified_name
@@ -22,10 +23,10 @@
2223
get_node_name,
2324
get_trt_tensor,
2425
)
26+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
2527
from torch_tensorrt.fx.observer import Observer
2628
from torch_tensorrt.logging import TRT_LOGGER
2729

28-
import tensorrt as trt
2930
from packaging import version
3031

3132
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -365,18 +366,29 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
365366
max_shape = current_input.shape["max_shape"]
366367
# TODO: Does not support disjoint optimization profiles?
367368
assert self.optimization_profiles is not None
368-
self.optimization_profiles[0].set_shape(
369-
target, min_shape, opt_shape, max_shape
370-
)
371-
372369
assert len(min_shape) == len(opt_shape) == len(max_shape)
373-
for i in range(len(min_shape)):
374-
if min_shape[i] == opt_shape[i] == max_shape[i]:
375-
shape.append(min_shape[i])
376-
else:
377-
# -1 to represent the dynamic dimension
378-
shape.append(-1)
379-
elif current_input.shape_mode == Input._ShapeMode.STATIC:
370+
if current_input.is_shape_tensor:
371+
# For shape_tensors, min/opt/max_shapes correspond to actual values
372+
# of the shapes provided during runtime
373+
self.optimization_profiles[0].set_shape_input(
374+
target, min_shape, opt_shape, max_shape
375+
)
376+
shape.append(len(opt_shape))
377+
else:
378+
self.optimization_profiles[0].set_shape(
379+
target, min_shape, opt_shape, max_shape
380+
)
381+
382+
for i in range(len(min_shape)):
383+
if min_shape[i] == opt_shape[i] == max_shape[i]:
384+
shape.append(min_shape[i])
385+
else:
386+
# -1 to represent the dynamic dimension
387+
shape.append(DYNAMIC_DIM)
388+
elif (
389+
not current_input.is_shape_tensor
390+
and current_input.shape_mode == Input._ShapeMode.STATIC
391+
):
380392
assert isinstance(current_input.shape, tuple)
381393
shape = list(current_input.shape)
382394
else:
@@ -388,6 +400,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
388400
_LOGGER.debug(
389401
f"Adding input to in-progress INetwork: {target} [shape={shape}, dtype={trt_input_dtype}]"
390402
)
403+
391404
return self.ctx.net.add_input(
392405
name=target,
393406
shape=tuple(shape),

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import logging
55
from typing import List, Sequence
66

7+
import tensorrt as trt
78
import torch
9+
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
810
from torch_tensorrt._Device import Device
911
from torch_tensorrt._enums import dtype
1012
from torch_tensorrt._features import ENABLED_FEATURES
@@ -17,8 +19,6 @@
1719
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1820
from torch_tensorrt.dynamo.utils import get_torch_inputs
1921

20-
import tensorrt as trt
21-
2222
logger = logging.getLogger(__name__)
2323

2424

@@ -28,12 +28,12 @@ def infer_module_output_dtypes(
2828
device: Device,
2929
truncate_double: bool = False,
3030
) -> List[dtype]:
31-
torch_inputs = get_torch_inputs(inputs, device)
32-
module = module.to(device.to(torch.device))
33-
module_outputs = module(*torch_inputs)
34-
35-
if not isinstance(module_outputs, (list, tuple)):
36-
module_outputs = [module_outputs]
31+
with maybe_disable_fake_tensor_mode():
32+
torch_inputs = get_torch_inputs(inputs, device)
33+
module = module.to(device.to(torch.device))
34+
module_outputs = module(*torch_inputs)
35+
if not isinstance(module_outputs, (list, tuple)):
36+
module_outputs = [module_outputs]
3737

3838
# Int64 outputs can sometimes be generated from within other operators
3939
# such as aten.sum - such outputs can be truncated

0 commit comments

Comments
 (0)