Skip to content

Commit 5cca652

Browse files
chohk88Hoonkyung Cho
andauthored
fix: cumsum add_constant bug fix (add dtype for np zeros) (#3258)
Co-authored-by: Hoonkyung Cho <[email protected]>
1 parent 7f22d7b commit 5cca652

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def cumsum(
370370
)
371371
else:
372372
new_dims = tuple(data.shape)
373-
zeros = np.zeros(new_dims)
373+
zeros = np.zeros(new_dims, dtype=np.float32)
374374
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
375375

376376
running_sum = loop.add_recurrence(zero_trttensor)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import warnings
45
from dataclasses import fields, replace
56
from enum import Enum
67
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -10,6 +11,8 @@
1011
import tensorrt as trt
1112
import torch
1213
from torch._subclasses.fake_tensor import FakeTensor
14+
15+
from packaging import version
1316
from torch_tensorrt._Device import Device
1417
from torch_tensorrt._enums import dtype
1518
from torch_tensorrt._features import ENABLED_FEATURES
@@ -19,8 +22,6 @@
1922
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
2023
from torch_tensorrt.dynamo._settings import CompilationSettings
2124

22-
from packaging import version
23-
2425
from .types import TRTDataType
2526

2627
logger = logging.getLogger(__name__)
@@ -494,6 +495,27 @@ def parse_dynamo_kwargs(
494495
if "options" in kwargs and len(kwargs) == 1:
495496
kwargs = kwargs["options"]
496497

498+
if "truncate_long_and_double" in kwargs:
499+
if (
500+
"truncate_double" in kwargs
501+
and kwargs["truncate_double"] is not _defaults.TRUNCATE_DOUBLE
502+
):
503+
raise ValueError(
504+
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double". '
505+
'Please only use "truncate_double".'
506+
)
507+
else:
508+
kwargs["truncate_double"] = kwargs["truncate_long_and_double"]
509+
warnings.warn(
510+
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported. '
511+
"This option will be removed in the next version.",
512+
DeprecationWarning,
513+
stacklevel=2,
514+
)
515+
del kwargs[
516+
"truncate_long_and_double"
517+
] # Remove deprecated key after handling
518+
497519
valid_attrs = {attr.name for attr in fields(settings)}
498520
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
499521
settings = replace(settings, **valid_kwargs)

0 commit comments

Comments
 (0)