Skip to content

Commit 13caff0

Browse files
committed
chore: Refactor code
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent b387dd5 commit 13caff0

22 files changed

+71
-1242
lines changed

.circleci/config.yml

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -711,15 +711,15 @@ commands:
711711
# =================== FX tests end ======================== #
712712

713713
# =================== Dynamo tests start ======================== #
714-
test-dynamo-fx_ts_core:
715-
description: "Test the Dynamo core"
714+
test-dynamo-fx_ts:
715+
description: "Test the Dynamo fx_ts_compat path"
716716
steps:
717717
- run:
718-
name: Run Dynamo core tests
718+
name: Run Dynamo fx_ts_compat core tests
719719
command: |
720720
cd py/torch_tensorrt/dynamo/fx_ts_compat/test
721721
pushd core/
722-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/core/test_results.xml
722+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml
723723
popd
724724
725725
- store_test_results:
@@ -737,27 +737,14 @@ commands:
737737
pushd test/
738738
pip3 install timm
739739
pip3 install transformers
740-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile
740+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
741741
popd
742742
743743
- store_test_results:
744744
path: /tmp/artifacts
745745
- store_artifacts:
746746
path: /tmp/testlogs
747747

748-
test-dynamo-fx_ts:
749-
description: "Test the dynamo backend"
750-
steps:
751-
- run:
752-
name: Run dynamo tests
753-
command: |
754-
mkdir -p /tmp/artifacts/test_results
755-
- test-dynamo-fx_ts_core
756-
- store_test_results:
757-
path: /tmp/artifacts
758-
- store_artifacts:
759-
path: /tmp/testlogs
760-
761748
# =================== Dynamo tests end ======================== #
762749

763750
# Define a job to be invoked later in a workflow.

py/torch_tensorrt/_Input.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class _ShapeMode(Enum):
4040
DOMAIN_OFFSET = 2.0
4141
low_tensor_domain_incl = 0.0
4242
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
43-
torch_dtype = None
43+
torch_dtype = torch.float32
4444

4545
def __init__(self, *args, **kwargs):
4646
"""__init__ Method for torch_tensorrt.Input
@@ -142,6 +142,7 @@ def __init__(self, *args, **kwargs):
142142
self.torch_dtype = kwargs["dtype"]
143143

144144
self.dtype = Input._parse_dtype(kwargs["dtype"])
145+
self.torch_dtype = Input._to_torch_dtype(self.dtype)
145146
self._explicit_set_dtype = True
146147

147148
if "format" in kwargs:
@@ -215,6 +216,22 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
215216
+ str(type(dtype))
216217
)
217218

219+
@staticmethod
220+
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
221+
if dtype == _enums.dtype.long:
222+
return torch.long
223+
elif dtype == _enums.dtype.int32:
224+
return torch.int32
225+
elif dtype == _enums.dtype.half:
226+
return torch.half
227+
elif dtype == _enums.dtype.float:
228+
return torch.float
229+
elif dtype == _enums.dtype.bool:
230+
return torch.bool
231+
else:
232+
# Default torch_dtype used in FX path
233+
return torch.float32
234+
218235
def is_trt_dtype(self) -> bool:
219236
return self.dtype != _enums.dtype.long
220237

@@ -368,9 +385,9 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
368385

369386
if self.shape_mode == Input._ShapeMode.STATIC:
370387
return torch.rand(self.shape).to(
371-
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
388+
dtype=self.torch_dtype
372389
)
373390
else:
374391
return torch.rand(self.shape[optimization_profile_field]).to(
375-
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
392+
dtype=self.torch_dtype
376393
)

py/torch_tensorrt/dynamo/fx_ts_compat/Dynamic_Shape_Support.md

Lines changed: 0 additions & 137 deletions
This file was deleted.

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
1717
from .input_tensor_spec import InputTensorSpec
1818
from torch_tensorrt.fx.observer import Observer
19-
from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
19+
from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
2020

2121
_LOGGER: logging.Logger = logging.getLogger(__name__)
2222

py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import torch
44

5-
from .types import Shape, ShapeRange
6-
from .utils import get_dynamic_dims
5+
from torch_tensorrt.fx.types import Shape, ShapeRange
6+
from torch_tensorrt.fx.utils import get_dynamic_dims
77
from torch_tensorrt._Input import Input
88

99

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from .lower_setting import LowerSetting
1515
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
1616
from .passes.pass_utils import PassFunc, validate_inference
17-
from .tools.timing_cache_utils import TimingCacheManager
18-
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting
17+
from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
18+
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
1919

2020
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
2121
from torch_tensorrt.fx.trt_module import TRTModule
22-
from .utils import LowerPrecision
22+
from torch_tensorrt.fx.utils import LowerPrecision
2323
from torch_tensorrt._Device import Device
2424

2525
logger = logging.getLogger(__name__)
@@ -36,12 +36,23 @@ def compile(
3636
enabled_precisions=set(),
3737
min_block_size: int = 3,
3838
workspace_size=0,
39-
verbose_log=False,
39+
dla_sram_size=1048576,
40+
dla_local_dram_size=1073741824,
41+
dla_global_dram_size=536870912,
42+
calibrator=None,
43+
truncate_long_and_double=False,
44+
require_full_compilation=False,
45+
debug=False,
46+
refit=False,
4047
timing_cache_prefix="",
4148
save_timing_cache=False,
4249
cuda_graph_batch_size=-1,
4350
is_aten=False,
4451
use_experimental_fx_rt=False,
52+
num_avg_timing_iters=1,
53+
torch_executed_ops=[],
54+
torch_executed_modules=[],
55+
**kwargs,
4556
) -> nn.Module:
4657
"""
4758
Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -52,7 +63,7 @@ def compile(
5263
input: Input for module.
5364
min_block_size: Minimal number of nodes for an accelerated submodule
5465
workspace_size: Maximum size of workspace given to TensorRT.
55-
verbose_log: Enable verbose log for TensorRT if set True.
66+
debug: Enable verbose log for TensorRT if set True.
5667
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
5768
save_timing_cache: Update timing cache with current timing cache data if set to True.
5869
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
@@ -65,6 +76,12 @@ def compile(
6576
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
6677
)
6778

79+
logger.warn(
80+
"For ir=fx_ts_compat backend only the "
81+
+ "following arguments are supported: "
82+
+ "{enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}"
83+
)
84+
6885
# Parse precision into LowerPrecision
6986
lower_precision = LowerPrecision.FP32
7087
if torch.float16 in enabled_precisions:
@@ -100,7 +117,7 @@ def compile(
100117
sparse_weights=sparse_weights,
101118
workspace_size=workspace_size,
102119
lower_precision=lower_precision,
103-
verbose_log=verbose_log,
120+
debug=debug,
104121
timing_cache_prefix=timing_cache_prefix,
105122
save_timing_cache=save_timing_cache,
106123
cuda_graph_batch_size=cuda_graph_batch_size,
@@ -148,7 +165,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
148165
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
149166
explicit_precision=self.lower_setting.explicit_precision,
150167
logger_level=trt.Logger.VERBOSE
151-
if self.lower_setting.verbose_log
168+
if self.lower_setting.debug
152169
else trt.Logger.WARNING,
153170
)
154171

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
fuse_permute_linear,
1010
fuse_permute_matmul,
1111
)
12-
from .utils import LowerPrecision
12+
from torch_tensorrt.fx.utils import LowerPrecision
1313

1414

1515
@dc.dataclass
@@ -54,7 +54,7 @@ class LowerSetting(LowerSettingBasic):
5454
as (a->b->c->d)=>(e). Current basic fuse patterns are:
5555
permute->linear
5656
permute->matmul
57-
verbose_log: Enable TensorRT engine verbose log mode.
57+
debug: Enable TensorRT engine verbose log mode.
5858
algo_selector: Enable TensorRT algorithm selector at execution time.
5959
timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing
6060
cache file at execution time if valid timing cache file is provided.
@@ -85,7 +85,7 @@ class LowerSetting(LowerSettingBasic):
8585
[fuse_permute_matmul, fuse_permute_linear]
8686
)
8787
)
88-
verbose_log: bool = False
88+
debug: bool = False
8989
algo_selector = None
9090
timing_cache_prefix: str = ""
9191
save_timing_cache: bool = False

py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
99
from torch.fx.passes.shape_prop import ShapeProp
1010
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
11-
from torch_tensorrt.dynamo.fx_ts_compat.utils import LowerPrecision
11+
from torch_tensorrt.fx.utils import LowerPrecision
1212
from torch_tensorrt import _Input
1313
from ..input_tensor_spec import InputTensorSpec
1414

py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_import_fx2trt.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)