|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import logging
|
| 4 | +import warnings |
4 | 5 | from dataclasses import fields, replace
|
5 | 6 | from enum import Enum
|
6 | 7 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
|
10 | 11 | import tensorrt as trt
|
11 | 12 | import torch
|
12 | 13 | from torch._subclasses.fake_tensor import FakeTensor
|
| 14 | + |
| 15 | +from packaging import version |
13 | 16 | from torch_tensorrt._Device import Device
|
14 | 17 | from torch_tensorrt._enums import dtype
|
15 | 18 | from torch_tensorrt._features import ENABLED_FEATURES
|
|
19 | 22 | from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
|
20 | 23 | from torch_tensorrt.dynamo._settings import CompilationSettings
|
21 | 24 |
|
22 |
| -from packaging import version |
23 |
| - |
24 | 25 | from .types import TRTDataType
|
25 | 26 |
|
26 | 27 | logger = logging.getLogger(__name__)
|
@@ -494,6 +495,27 @@ def parse_dynamo_kwargs(
|
494 | 495 | if "options" in kwargs and len(kwargs) == 1:
|
495 | 496 | kwargs = kwargs["options"]
|
496 | 497 |
|
| 498 | + if "truncate_long_and_double" in kwargs: |
| 499 | + if ( |
| 500 | + "truncate_double" in kwargs |
| 501 | + and kwargs["truncate_double"] is not _defaults.TRUNCATE_DOUBLE |
| 502 | + ): |
| 503 | + raise ValueError( |
| 504 | + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double". ' |
| 505 | + 'Please only use "truncate_double".' |
| 506 | + ) |
| 507 | + else: |
| 508 | + kwargs["truncate_double"] = kwargs["truncate_long_and_double"] |
| 509 | + warnings.warn( |
| 510 | + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported. ' |
| 511 | + "This option will be removed in the next version.", |
| 512 | + DeprecationWarning, |
| 513 | + stacklevel=2, |
| 514 | + ) |
| 515 | + del kwargs[ |
| 516 | + "truncate_long_and_double" |
| 517 | + ] # Remove deprecated key after handling |
| 518 | + |
497 | 519 | valid_attrs = {attr.name for attr in fields(settings)}
|
498 | 520 | valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
|
499 | 521 | settings = replace(settings, **valid_kwargs)
|
|
0 commit comments