Skip to content

Commit 9610ba7

Browse files
committed
Key op fixes for failing tests
1 parent d022f4a commit 9610ba7

File tree

4 files changed

+37
-33
lines changed

4 files changed

+37
-33
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, List, Optional, Sequence, Union, cast
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -19,8 +20,6 @@
1920
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
2021
from torch_tensorrt.fx.utils import get_dynamic_dims
2122

22-
import tensorrt as trt
23-
2423
_LOGGER: logging.Logger = logging.getLogger(__name__)
2524

2625

@@ -101,9 +100,15 @@ def layer_norm(
101100
"of the TensorRT region!"
102101
)
103102

104-
gamma = weight.detach().cpu().float().numpy()
103+
gamma = (
104+
weight.detach().cpu().float().numpy()
105+
if isinstance(weight, torch.Tensor)
106+
else weight
107+
)
105108
gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
106-
beta = bias.detach().cpu().float().numpy()
109+
beta = (
110+
bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
111+
)
107112
beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
108113
eps_field = trt.PluginField(
109114
"eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,27 @@
33
import math
44
import operator
55
import warnings
6-
from typing import cast, Dict, Optional, Sequence, Tuple, Union
6+
from typing import Dict, Optional, Sequence, Tuple, Union, cast
77

88
import numpy as np
99

1010
# @manual=//deeplearning/trt/python:py_tensorrt
1111
import tensorrt as trt
1212
import torch
13-
14-
from ..converter_registry import tensorrt_converter
15-
16-
from ..tracer.acc_tracer import acc_ops
17-
from ..types import * # noqa: F403
1813
from torch.fx.immutable_collections import immutable_list
1914
from torch.fx.node import Argument, Target
20-
21-
from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
22-
23-
from .converter_utils import * # noqa: F403
15+
from torch_tensorrt.fx.converters.impl import activation, convolution
2416
from torch_tensorrt.fx.passes.lower_basic_pass import (
2517
trt_transposed_linear,
2618
trt_transposed_matmul,
2719
)
2820
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29-
from torch_tensorrt.fx.converters.impl import activation, convolution
21+
22+
from ..converter_registry import tensorrt_converter
23+
from ..tracer.acc_tracer import acc_ops
24+
from ..types import * # noqa: F403
25+
from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter
26+
from .converter_utils import * # noqa: F403
3027

3128
_LOGGER: logging.Logger = logging.getLogger(__name__)
3229

@@ -2714,8 +2711,14 @@ def acc_ops_linear(
27142711
"dim for linear and it can't be the last dim."
27152712
)
27162713

2717-
if isinstance(kwargs["weight"], torch.Tensor):
2718-
weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
2714+
if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
2715+
weight = get_trt_tensor(
2716+
network,
2717+
kwargs["weight"].t()
2718+
if isinstance(kwargs["weight"], torch.Tensor)
2719+
else kwargs["weight"].T,
2720+
f"{name}_weight",
2721+
)
27192722
if target not in (acc_ops.linear, torch.ops.aten.linear):
27202723
weight_op = trt.MatrixOperation.TRANSPOSE
27212724
else:

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import operator
22
import warnings
3+
from enum import Enum, auto
34
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
45

5-
from enum import Enum, auto
66
import numpy as np
77

88
# @manual=//deeplearning/trt/python:py_tensorrt
@@ -20,7 +20,7 @@
2020
TRTPluginFieldCollection,
2121
TRTTensor,
2222
)
23-
from ..utils import unified_dtype_converter, Frameworks
23+
from ..utils import Frameworks, unified_dtype_converter
2424

2525

2626
class SourceIR(Enum):
@@ -271,7 +271,7 @@ def create_constant(
271271
"""
272272
constant = network.add_constant(
273273
(1,) if isinstance(value, (int, float)) else value.shape,
274-
to_numpy(value, dtype),
274+
to_numpy(value, dtype).copy(),
275275
)
276276
constant.name = name
277277
return constant.get_output(0)
@@ -311,7 +311,7 @@ def get_trt_tensor(
311311
elif isinstance(input_val, np.ndarray) and (
312312
input_val.dtype == np.bool_ or input_val.dtype == np.int64
313313
):
314-
input_val = input_val.to(np.int32)
314+
input_val = input_val.astype(np.int32)
315315

316316
if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
317317
return create_constant(network, input_val, name, dtype)

py/torch_tensorrt/fx/converters/impl/convolution.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
1-
import numpy as np
21
from typing import Any, Optional, Sequence, Union
32

3+
import numpy as np
4+
45
# @manual=//deeplearning/trt/python:py_tensorrt
56
import tensorrt as trt
67
import torch
78
from torch.fx.node import Target
8-
9+
from torch_tensorrt.fx.converters import acc_ops_converters
910
from torch_tensorrt.fx.converters.converter_utils import (
1011
SourceIR,
1112
extend_attr_to_tuple,
1213
get_dyn_range,
14+
get_trt_tensor,
15+
has_dynamic_shape,
1316
mark_as_int8_layer,
1417
set_layer_name,
15-
has_dynamic_shape,
1618
to_numpy,
17-
get_trt_tensor,
18-
)
19-
from torch_tensorrt.fx.converters import acc_ops_converters
20-
21-
from torch_tensorrt.fx.types import (
22-
TRTNetwork,
23-
TRTTensor,
2419
)
20+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
2521

2622

2723
def convNd(
@@ -54,7 +50,7 @@ def convNd(
5450
)
5551

5652
# Process bias terms
57-
if isinstance(bias, torch.Tensor):
53+
if isinstance(bias, (torch.Tensor, np.ndarray)):
5854
# Transform the bias constant into a Numpy array
5955
bias = to_numpy(bias)
6056

@@ -79,7 +75,7 @@ def convNd(
7975
network, target, tuple(), kwargs, name + "_unsqueeze_weight"
8076
)
8177

82-
elif isinstance(weight, torch.Tensor):
78+
elif isinstance(weight, (torch.Tensor, np.ndarray)):
8379
# Transform the weight constant into a Numpy array
8480
weight = to_numpy(weight)
8581

0 commit comments

Comments
 (0)