Skip to content

Commit 2860be6

Browse files
committed
Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into export_prototype
2 parents ab76c0d + e19aae7 commit 2860be6

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+2162
-162
lines changed

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
136136
TORCHTRT_CHECK(
137137
inputs[i].dtype() == expected_type,
138138
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
139-
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
139+
auto dims = core::util::toDims(inputs[i].sizes());
140140
auto shape = core::util::toVec(dims);
141141
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
142142
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);

docker/Dockerfile

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,17 @@ COPY --from=torch-tensorrt-builder /workspace/torch_tensorrt/src/dist/ .
106106

107107
RUN cp /opt/torch_tensorrt/docker/WORKSPACE.docker /opt/torch_tensorrt/WORKSPACE
108108
RUN pip install -r /opt/torch_tensorrt/py/requirements.txt
109+
# Install all dependency wheel files and user-specified TensorRT
110+
RUN pip install *.whl
109111
RUN pip install tensorrt==${TENSORRT_VERSION}.*
110-
RUN pip install *.whl && rm -fr /workspace/torch_tensorrt/dist/* *.whl
112+
113+
# Add the Torch-TensorRT wheel file to the dist directory and delete all other .whl files
114+
RUN rm -fr /workspace/torch_tensorrt/dist/*
115+
RUN mkdir -p /opt/torch_tensorrt/dist/ && mv torch_tensorrt*.whl /opt/torch_tensorrt/dist/
116+
RUN rm -fr *.whl
117+
118+
# Remove other cache files if present
119+
RUN pip cache purge && rm -rf /opt/torch_tensorrt/.mypy_cache
111120

112121
WORKDIR /opt/torch_tensorrt
113122

py/torch_tensorrt/_Device.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,10 @@
88

99
import warnings
1010

11-
import torch
12-
from torch_tensorrt import logging
13-
1411
# from torch_tensorrt import _enums
1512
import tensorrt as trt
13+
import torch
14+
from torch_tensorrt import logging
1615

1716
try:
1817
from torch_tensorrt import _C
@@ -120,6 +119,9 @@ def __str__(self) -> str:
120119
)
121120
)
122121

122+
def __repr__(self) -> str:
123+
return self.__str__()
124+
123125
def _to_internal(self) -> _C.Device:
124126
internal_dev = _C.Device()
125127
if self.device_type == trt.DeviceType.GPU:

py/torch_tensorrt/_Input.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -339,18 +339,18 @@ def from_tensor(
339339
A Input object.
340340
"""
341341
if not (
342-
t.is_contiguous(memory_format=torch.contiguous_format)
342+
disable_memory_format_check
343+
or t.is_contiguous(memory_format=torch.contiguous_format)
343344
or t.is_contiguous(memory_format=torch.channels_last)
344-
or disable_memory_format_check
345345
):
346346
raise ValueError(
347347
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
348348
)
349349
frmt = (
350350
torch.contiguous_format
351351
if (
352-
t.is_contiguous(memory_format=torch.contiguous_format)
353-
or disable_memory_format_check
352+
disable_memory_format_check
353+
or t.is_contiguous(memory_format=torch.contiguous_format)
354354
)
355355
else torch.channels_last
356356
)

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -209,12 +209,12 @@ def compile(
209209
import collections.abc
210210

211211
from torch_tensorrt import Device
212-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
212+
from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
213213

214214
if not isinstance(inputs, collections.abc.Sequence):
215215
inputs = [inputs]
216216
device = kwargs.get("device", Device._current_device())
217-
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
217+
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
218218
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
219219
compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
220220
module,
@@ -239,7 +239,10 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
239239
"""
240240
from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
241241

242-
boxed_fn = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})
242+
# TODO: Remove dynamic=False when SymInt Dynamic shape support is ready
243+
boxed_fn = torch.compile(
244+
module, backend=torch_tensorrt_backend, dynamic=False, options={**kwargs}
245+
)
243246

244247
return boxed_fn
245248

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
import torch
2+
from torch_tensorrt._Device import Device
23

34
PRECISION = torch.float32
45
DEBUG = False
6+
DEVICE = None
57
WORKSPACE_SIZE = 0
68
MIN_BLOCK_SIZE = 5
79
PASS_THROUGH_BUILD_FAILURES = False
@@ -12,3 +14,7 @@
1214
USE_PYTHON_RUNTIME = False
1315
USE_FAST_PARTITIONER = True
1416
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
17+
18+
19+
def default_device() -> Device:
20+
return Device(gpu_id=torch.cuda.current_device())

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Set
33

44
import torch
5+
from torch_tensorrt._Device import Device
56
from torch_tensorrt.dynamo._defaults import (
67
DEBUG,
78
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -15,6 +16,7 @@
1516
USE_PYTHON_RUNTIME,
1617
VERSION_COMPATIBLE,
1718
WORKSPACE_SIZE,
19+
default_device,
1820
)
1921

2022

@@ -54,3 +56,4 @@ class CompilationSettings:
5456
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
5557
use_fast_partitioner: bool = USE_FAST_PARTITIONER
5658
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
59+
device: Device = field(default_factory=default_device)

py/torch_tensorrt/dynamo/compile.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import collections.abc
44
import logging
5-
from typing import Any, List, Optional, Sequence, Set, Tuple
5+
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

77
import torch
88
import torch_tensorrt
@@ -13,6 +13,7 @@
1313
from torch_tensorrt.dynamo import CompilationSettings, partitioning
1414
from torch_tensorrt.dynamo._defaults import (
1515
DEBUG,
16+
DEVICE,
1617
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1718
MAX_AUX_STREAMS,
1819
MIN_BLOCK_SIZE,
@@ -29,7 +30,11 @@
2930
convert_module,
3031
repair_long_or_double_inputs,
3132
)
32-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
33+
from torch_tensorrt.dynamo.utils import (
34+
prepare_inputs,
35+
to_torch_device,
36+
to_torch_tensorrt_device,
37+
)
3338

3439
logger = logging.getLogger(__name__)
3540

@@ -38,7 +43,7 @@ def compile(
3843
gm: Any,
3944
inputs: Any,
4045
*,
41-
device: Device = Device._current_device(),
46+
device: Optional[Union[Device, torch.device, str]] = DEVICE,
4247
disable_tf32: bool = False,
4348
sparse_weights: bool = False,
4449
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
@@ -82,7 +87,9 @@ def compile(
8287
if not isinstance(inputs, collections.abc.Sequence):
8388
inputs = [inputs]
8489

85-
_, torch_inputs = prepare_inputs(inputs, prepare_device(device))
90+
device = to_torch_tensorrt_device(device)
91+
92+
_, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
8693

8794
if (
8895
torch.float16 in enabled_precisions
@@ -105,6 +112,7 @@ def compile(
105112
compilation_options = {
106113
"precision": precision,
107114
"debug": debug,
115+
"device": device,
108116
"workspace_size": workspace_size,
109117
"min_block_size": min_block_size,
110118
"torch_executed_ops": torch_executed_ops

0 commit comments

Comments
 (0)