Skip to content

Commit 7592715

Browse files
committed
Implemented basic pipeline for Refitting (#2886)
1 parent 41b3928 commit 7592715

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

docsrc/py_api/dynamo.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Functions
2626

2727
.. autofunction:: refit_module_weights
2828

29-
3029
Classes
3130
--------
3231

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,6 @@ def compile(
176176
else:
177177
make_refitable = kwargs["refit"]
178178

179-
if kwarg_inputs is None:
180-
kwarg_inputs = {}
181179
engine_capability = EngineCapability._from(engine_capability)
182180

183181
if torch_executed_modules is not None and torch_executed_modules:

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from typing import Any, Sequence, Tuple
77

88
import numpy as np
9+
<<<<<<< HEAD
10+
=======
11+
import tensorrt as trt
12+
>>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886))
913
import torch
1014
from torch.export import ExportedProgram
1115
from torch_tensorrt._enums import dtype
@@ -42,8 +46,11 @@
4246
)
4347
from torch_tensorrt.logging import TRT_LOGGER
4448

49+
<<<<<<< HEAD
4550
import tensorrt as trt
4651

52+
=======
53+
>>>>>>> 9f46d3940 (Implemented basic pipeline for Refitting (#2886))
4754
logger = logging.getLogger(__name__)
4855

4956

@@ -96,16 +103,12 @@ def construct_refit_mapping(
96103
layer_type: str = layer.type.name
97104
if layer_type in MODULE_MAP:
98105
# Cast the parent class to child class to access attributes
99-
# For example: ILayer does not have ILayer.kernel/ILayer.bias
106+
# For example: ILayer does not have ILayer.kernal/ILayer.bias
100107
# So we cast it to IConvolutionLayer and access the attributes
101108
layer.__class__ = MODULE_MAP[layer_type][0]
102109
for weight_type, weight_name in MODULE_MAP[layer_type][1]:
103110
weight = layer.__getattribute__(weight_type).copy()
104-
weight_dtype_opt = dtype.try_from(weight.dtype)
105-
assert (
106-
weight_dtype_opt is not None
107-
), f"Weights {weight_name} has unsupported dtype {weight.dtype}"
108-
weight_dtype = weight_dtype_opt.to(trt.DataType)
111+
weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
109112
weight_map[f"{layer.name} {weight_name}"] = (
110113
weight,
111114
weight_dtype,

0 commit comments

Comments
 (0)