
Commit c1d2423

Cherrypick #3481 and #3445 (#3498)

1 parent b3ef68e commit c1d2423

File tree

11 files changed: +339, -190 lines


py/torch_tensorrt/_features.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@
 _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
 _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
 _FX_FE_AVAIL = True
-_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
+_REFIT_AVAIL = True

 ENABLED_FEATURES = FeatureSet(
     _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
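
Note: this change removes the Python 3.13 gate, so refit support is now reported as available on every build. A minimal sketch of gating user code on the flag, assuming ENABLED_FEATURES is re-exported at the torch_tensorrt package top level and that the FeatureSet field is named `refit` (both assumptions, not shown in this diff):

    import torch_tensorrt

    # Field name `refit` is an assumption; check FeatureSet in _features.py
    # for the authoritative attribute names.
    if torch_tensorrt.ENABLED_FEATURES.refit:
        print("refit support is available in this build")
    else:
        print("refit support is disabled")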

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 66 additions & 93 deletions

@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
@@ -61,26 +62,13 @@ def construct_refit_mapping(
     Returns:
         Mapping from weight name in TensorRT to actual weight value in np.ndarray
     """
-    MODULE_MAP = {
-        "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
-        "CONVOLUTION": (
-            trt.IConvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "DECONVOLUTION": (
-            trt.IDeconvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
-    }

     output_dtypes = infer_module_output_dtypes(
         module,
         truncate_double=settings.truncate_double,
     )

     # Use Interpreter
-    weight_map = {}
     interpreter = TRTInterpreter(
         module,
         inputs,
@@ -89,24 +77,8 @@
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
-    net = interpreter.ctx.net
-    for i in range(net.num_layers):
-        layer = net[i]
-        layer_type: str = layer.type.name
-        if layer_type in MODULE_MAP:
-            # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
-            # So we cast it to IConvolutionLayer and access the attributes
-            layer.__class__ = MODULE_MAP[layer_type][0]
-            for weight_type, weight_name in MODULE_MAP[layer_type][1]:
-                weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
-                weight_map[f"{layer.name} {weight_name}"] = (
-                    weight,
-                    weight_dtype,
-                )

-    return weight_map
+    return interpreter.ctx.mapping


 @needs_refit
@@ -117,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-
         if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
+            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
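
Note: the mapping code now uses dtype._from, the strict variant of dtype.try_from that raises on unrecognized dtypes instead of returning None. A small, self-contained illustration of the conversion pattern (assumes standard tensorrt and torch_tensorrt installs):

    import numpy as np
    import tensorrt as trt
    import torch
    from torch_tensorrt._enums import dtype

    # Map a NumPy dtype to its TensorRT and torch equivalents, as the
    # refit mapping code does above.
    trt_dtype = dtype._from(np.float32).to(trt.DataType)
    torch_dtype = dtype._from(np.float32).to(torch.dtype)
    print(trt_dtype, torch_dtype)  # e.g. DataType.FLOAT torch.float32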
@@ -152,71 +123,73 @@ def _refit_single_trt_engine_with_gm(
     Refit a TensorRT Engine in place
     """

-    refitted = set()
-    torch_device = get_model_device(new_gm)
-    refitter = trt.Refitter(old_engine, TRT_LOGGER)
-    weight_list = refitter.get_all_weights()
-
-    if weight_name_map:
-        # Get the refitting mapping
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE
-            if torch_device.type == "cuda"
-            else trt.TensorLocation.HOST
-        )
+    with unset_fake_temporarily():
+        refitted = set()
+        torch_device = get_model_device(new_gm)
+        refitter = trt.Refitter(old_engine, TRT_LOGGER)
+        weight_list = refitter.get_all_weights()
+
+        if weight_name_map:
+            # Get the refitting mapping
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )

-        constant_mapping: dict[str, Any] = weight_name_map.pop(
-            "constant_mapping", {}
-        )  # type: ignore
-        mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict(), settings
-        )
-        constant_mapping_with_type = {}
-
-        for constant_name, val in constant_mapping.items():
-            np_weight_type = val.dtype
-            val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-            constant_mapping_with_type[constant_name] = (
-                val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
-                trt_dtype,
+            constant_mapping: dict[str, Any] = weight_name_map.pop(
+                "constant_mapping", {}
+            )  # type: ignore
+            mapping = construct_refit_mapping_from_weight_name_map(
+                weight_name_map, new_gm.state_dict(), settings
             )
+            constant_mapping_with_type = {}
+
+            for constant_name, val in constant_mapping.items():
+                np_weight_type = val.dtype
+                val_tensor = torch.from_numpy(val).cuda()
+                trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+                torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
+                constant_mapping_with_type[constant_name] = (
+                    val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
+                    trt_dtype,
+                )

-        mapping.update(constant_mapping_with_type)
+            mapping.update(constant_mapping_with_type)

-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                logger.warning(f"{layer_name} is not found in weight mapping.")
-                continue
-            # Use Numpy to create weights
-            weight, weight_dtype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(
-                weight_dtype, weight.data_ptr(), torch.numel(weight)
-            )
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        assert (
-            len(refitter.get_missing_weights()) == 0
-        ), "Fast refitting failed due to incomplete mapping"
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Use Numpy to create weights
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"

-    else:
-        mapping = construct_refit_mapping(new_gm, input_list, settings)
-        trt_wt_location = trt.TensorLocation.HOST
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-            refitted.add(layer_name)
-
-        if len(refitted) != len(weight_list):
-            logger.warning("Not all weights have been refitted!!!")
-
-        if not refitter.refit_cuda_engine():
-            logger.error("Error: failed to refit new weights.")
-            raise AssertionError("Refitting failed.")
+        else:
+            mapping = construct_refit_mapping(new_gm, input_list, settings)
+            trt_wt_location = trt.TensorLocation.HOST
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    raise AssertionError(f"{layer_name} is not found in weight mapping")
+                # Use Numpy to create weights
+                weight = mapping[layer_name]
+                trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
+                trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+                refitted.add(layer_name)
+
+            if len(refitted) != len(weight_list):
+                logger.warning("Not all weights have been refitted!!!")
+
+            if not refitter.refit_cuda_engine():
+                logger.error("Error: failed to refit new weights.")
+                raise AssertionError("Refitting failed.")


 @needs_refit
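
For context, a hedged end-to-end sketch of the fast-refit workflow this function serves. The public entry point refit_module_weights and the immutable_weights compile flag are assumptions based on the torch_tensorrt.dynamo API of this era, not shown in this diff; a CUDA build is required:

    import torch
    import torch_tensorrt

    class TinyNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    inputs = (torch.randn(2, 4).cuda(),)
    exp_program = torch.export.export(TinyNet().cuda().eval(), inputs)

    # Build a refittable engine; the exact flag name may differ by release.
    trt_module = torch_tensorrt.dynamo.compile(
        exp_program, inputs=list(inputs), immutable_weights=False
    )

    # Later: swap in newly trained weights without rebuilding the engine.
    new_program = torch.export.export(TinyNet().cuda().eval(), inputs)
    refitted = torch_tensorrt.dynamo.refit_module_weights(
        trt_module, new_program, arg_inputs=list(inputs)
    )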

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field

+import numpy as np
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.fx.types import TRTNetwork

@@ -19,3 +20,4 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
+    mapping: dict[str, np.array] = field(default_factory=dict)
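
The new mapping field gives converters a place to record weights while the network is built, which construct_refit_mapping in _refit.py now returns directly. A minimal sketch of the data shape (illustrative only; a real ConversionContext also wraps a TensorRT INetworkDefinition):

    import numpy as np

    # Keys follow the "<layer name> <weight role>" convention used by the
    # refit code; values are the raw weight arrays.
    mapping: dict[str, np.ndarray] = {}
    mapping["linear1 KERNEL"] = np.random.rand(4, 4).astype(np.float32)
    mapping["linear1 BIAS"] = np.zeros(4, dtype=np.float32)
    print(sorted(mapping))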

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 24 additions & 30 deletions

@@ -21,6 +21,7 @@
 import tensorrt as trt
 import torch
 import torch.fx
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes

@@ -42,6 +43,7 @@
     get_node_io,
     get_node_name,
     get_trt_tensor,
+    to_torch,
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
@@ -410,27 +412,29 @@ def find_weight(
             np_map: the map from weight name to np values in INetworkDefinition
             state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
-        for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
-                del state_dict[sd_w_name]
-                return sd_w_name
-        return ""
+        with unset_fake_temporarily():
+            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            for sd_w_name, sd_weight in state_dict.items():
+                if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
+                    del state_dict[sd_w_name]
+                    return sd_w_name
+            return ""

     @staticmethod
     def check_weight_equal(
         sd_weight: torch.tensor,
         network_weight: Union[torch.Tensor, np.ndarray],
         device: torch.device,
     ) -> Any:
-        if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).to(device)
-        try:
-            return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.01
-            )
-        except Exception:
-            return torch.all(sd_weight == network_weight)
+        with unset_fake_temporarily():
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).to(device)
+            try:
+                return sd_weight.shape == network_weight.shape and torch.all(
+                    torch.abs(sd_weight - network_weight) < 0.01
+                )
+            except Exception:
+                return torch.all(sd_weight == network_weight)

     @needs_refit
     def _save_weight_mapping(self) -> None:
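
Both helpers now run with fake-tensor tracing suspended. Under FakeTensorMode, tensor constructors return shape-only fakes with no real storage, so element-wise weight comparison would be meaningless; unset_fake_temporarily restores real-tensor semantics inside the block. A self-contained illustration:

    import numpy as np
    import torch
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

    with FakeTensorMode():
        fake = torch.empty(4)  # a FakeTensor: metadata only, no data
        with unset_fake_temporarily():
            real = torch.from_numpy(np.ones(4, dtype=np.float32))
            assert bool(torch.all(real == 1.0))  # real values are comparable

    assert isinstance(fake, FakeTensor)
    assert not isinstance(real, FakeTensor)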
@@ -495,19 +499,15 @@ def _save_weight_mapping(self) -> None:
             for k, v in self.module.state_dict().items()
         }
         weight_name_map: dict[str, Any] = {}
-        np_map = {}
-        constant_mapping = {}
+        np_map = self.ctx.mapping
+        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
             layer_type: str = layer.type.name
             if layer_type in MODULE_MAP:
-                layer.__class__ = MODULE_MAP[layer_type][0]
                 # Name mapping
                 for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
-                    weight = layer.__getattribute__(weight_type).copy()
-                    if weight.size == 0:
-                        continue
                     engine_weight_name = f"{layer.name} {weight_name}"
                     # Infer the corresponding weight name(s) in state_dict
                     sd_weight_name_list = (
@@ -535,17 +535,15 @@
                         elif "bias" in suffix:
                             sd_weight_name = f"{sd_weight_name}.bias"
                         else:
-                            # Save the constant weights for future fast refit
                             sd_weight_name = f"{sd_weight_name}.unknown"
-                            constant_mapping[engine_weight_name] = weight
                     elif layer_type == "SCALE":
                         # Batch norm needs all weights to calculate scale and shift
                         sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
                     else:
                         sd_weight_name = f"{sd_weight_name}.{torch_attr}"

-                    weight_name_map[engine_weight_name] = sd_weight_name
-                    np_map[engine_weight_name] = weight
+                    if engine_weight_name in np_map:
+                        weight_name_map[engine_weight_name] = sd_weight_name

         # Stage 2: Value mapping
         for engine_weight_name, sd_weight_name in weight_name_map.items():
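
Constants are no longer captured in a separate bookkeeping pass; they are derived from the recorded mapping by size, treating single-element arrays as constants. A one-line illustration of that filter:

    import numpy as np

    np_map = {
        "conv1 KERNEL": np.zeros((8, 3, 3, 3), dtype=np.float32),
        "bn1 SHIFT": np.zeros(1, dtype=np.float32),
    }
    constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
    print(list(constant_mapping))  # ['bn1 SHIFT']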
@@ -887,19 +885,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         return converter(self.ctx, target, args, kwargs, self._cur_node_name)

     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
-        with _disable_current_modes():
-            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
-
+        with _disable_current_modes(), unset_fake_temporarily():
             frozen_attr = self.fetch_attr(target)

             if isinstance(frozen_attr, torch.nn.Parameter):
                 constant_tensor = frozen_attr.data
             else:
                 constant_tensor = frozen_attr

-            network_constant = to_numpy(constant_tensor)
-
-            return network_constant
+            return to_torch(constant_tensor)

     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
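
get_attr now returns frozen attributes as torch.Tensors via to_torch instead of converting them to NumPy, so recorded weights keep their device and dtype (the np.ndarray return annotation is now stale). A toy stand-in for the contract; to_torch_like is hypothetical, and the real to_torch lives in converter_utils:

    import torch

    def to_torch_like(value):  # hypothetical stand-in for converter_utils.to_torch
        # Return a plain, contiguous torch.Tensor for a frozen attribute.
        if isinstance(value, torch.nn.Parameter):
            value = value.data
        return torch.as_tensor(value).contiguous()

    p = torch.nn.Parameter(torch.randn(3, 3))
    t = to_torch_like(p)
    assert isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter)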
