Skip to content

Commit db1b7af

Browse files
committed
fix: Upgrade to_numpy to allow boolean constants
1 parent 5f02996 commit db1b7af

File tree

7 files changed

+57
-9
lines changed

7 files changed

+57
-9
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
345345

346346
def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
347347
with _disable_current_modes():
348-
from torch_tensorrt.fx.converters import to_numpy
348+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
349349

350350
frozen_attr = self.fetch_attr(target)
351351

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from torch_tensorrt.fx.converters.converter_utils import (
1818
Frameworks,
1919
get_axes_for_reduce_op,
20-
to_numpy,
2120
unified_dtype_converter,
2221
)
2322
from torch_tensorrt.fx.types import TRTDataType, TRTTensor
@@ -270,9 +269,10 @@ def create_constant(
270269
Returns:
271270
A TensorRT ITensor that represents the given value.
272271
"""
272+
numpy_value = to_numpy(value, dtype)
273273
constant = ctx.net.add_constant(
274274
(1,) if isinstance(value, (int, float, bool)) else value.shape,
275-
to_numpy(value, dtype).copy(),
275+
numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
276276
)
277277
constant.name = name
278278
return constant.get_output(0)
@@ -414,3 +414,50 @@ def convert_with_type_enforcement(
414414
return convert_with_type_enforcement
415415

416416
return wrapper
417+
418+
419+
def to_numpy(
420+
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
421+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
422+
) -> Optional[np.ndarray]:
423+
"""
424+
Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is
425+
quantized it will be dequantized first.
426+
Args:
427+
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
428+
A PyTorch tensor, Numpy array, int, float, or bool
429+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
430+
If a dtype is given, we will convert the type of the given `value` to this dtype.
431+
Returns:
432+
A Numpy array or None, if the input was None.
433+
"""
434+
output = None
435+
436+
if value is None or isinstance(value, np.ndarray):
437+
output = value
438+
439+
elif isinstance(value, torch.Tensor):
440+
if value.is_quantized:
441+
value = value.dequantize()
442+
443+
output = value.cpu().detach().contiguous().numpy()
444+
445+
elif isinstance(value, int):
446+
output = np.array([value], dtype=np.int32)
447+
448+
elif isinstance(value, float):
449+
output = np.array([value], dtype=np.float32)
450+
451+
elif isinstance(value, bool):
452+
output = np.array([value], dtype=np.bool_)
453+
454+
if isinstance(output, np.ndarray) or output is None:
455+
return (
456+
output
457+
if (dtype is None or output is None)
458+
else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
459+
)
460+
else:
461+
raise AssertionError(
462+
f"to_numpy can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got: {value}"
463+
)

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
from torch_tensorrt.dynamo.conversion import impl
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
12+
SourceIR,
1213
extend_attr_to_tuple,
1314
get_trt_tensor,
15+
to_numpy,
1416
)
1517
from torch_tensorrt.fx.converters.converter_utils import (
16-
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
from torch_tensorrt.dynamo.conversion.converter_utils import (
1212
extend_attr_to_tuple,
1313
get_trt_tensor,
14+
to_numpy,
1415
)
1516
from torch_tensorrt.fx.converters.converter_utils import (
1617
SourceIR,
1718
get_dyn_range,
1819
has_dynamic_shape,
1920
mark_as_int8_layer,
2021
set_layer_name,
21-
to_numpy,
2222
)
2323
from torch_tensorrt.fx.types import TRTTensor
2424

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
99
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
10+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1011
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1112
convert_binary_elementwise,
1213
)
@@ -16,7 +17,6 @@
1617
get_trt_plugin,
1718
has_dynamic_shape,
1819
set_layer_name,
19-
to_numpy,
2020
)
2121
from torch_tensorrt.fx.types import TRTTensor
2222
from torch_tensorrt.fx.utils import get_dynamic_dims

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
66
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
7+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
78
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
89
from torch_tensorrt.fx.converters.converter_utils import (
910
get_positive_dim,
1011
has_dynamic_shape,
11-
to_numpy,
1212
)
1313
from torch_tensorrt.fx.types import Shape, TRTTensor
1414

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from torch.fx.node import Target
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
11+
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
1112
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1213
convert_binary_elementwise,
1314
)
14-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name, to_numpy
15+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
1516
from torch_tensorrt.fx.types import TRTTensor
1617

1718

0 commit comments

Comments
 (0)