
Commit 18f74b5

Dark Knight authored and Wei Wei committed
Revert D34929680: Multisect successfully blamed D34929680 for test failures (#74381)
Summary:
X-link: pytorch/pytorch#74381
X-link: pytorch/benchmark#808
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/25

Reviewed By: brad-mengchi

Differential Revision: D34966585

fbshipit-source-id: a1eea214ba6c9a7c04dd9d327f339bc1c739b0ae
1 parent 9507c6d commit 18f74b5

6 files changed: +30 additions, -31 deletions


fx/example/lower_example.py

Lines changed: 1 addition & 2 deletions
@@ -5,7 +5,6 @@
 import torch
 import torchvision
 from fx2trt_oss.fx import lower_to_trt
-from fx2trt_oss.fx.utils import LowerPrecision


 """
@@ -168,7 +167,7 @@ def run_configuration_benchmark(
         time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
     elif not conf.jit:
         # Run lowering eager mode benchmark
-        lowered_module = lower_to_trt(module, input, max_batch_size=conf.batch_size, lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32)
+        lowered_module = lower_to_trt(module, input, max_batch_size=conf.batch_size, fp16_mode=conf.fp16)
         time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
     else:
         print("Lowering with JIT is not available!", "red")

fx/example/quantized_resnet_test.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 import torch.fx
 import torchvision.models as models
 from fx2trt_oss.fx import TRTInterpreter, InputTensorSpec, TRTModule
-from fx2trt_oss.fx.utils import LowerPrecision
 from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
 import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
 import copy
@@ -17,7 +16,7 @@ def build_fp16_trt(rn18):
     rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
     interp = TRTInterpreter(
         rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
-    interpreter_result = interp.run(lower_precision=LowerPrecision.FP16)
+    interpreter_result = interp.run(fp16_mode=True)
     return TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)

 @torch.no_grad()
@@ -48,7 +47,7 @@ def build_int8_trt(rn18):
         [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
             shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
         explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
-    interpreter_result = interp.run(lower_precision=LowerPrecision.INT8)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
     trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
@@ -76,7 +75,7 @@ def build_int8_trt_implicit_quant(rn18):
     shape_prop.ShapeProp(traced_rn18).propagate(data)
     traced_rn18 = NormalizeArgs(traced_rn18).transform()
     interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
-    interpreter_result = interp.run(lower_precision=LowerPrecision.INT8, strict_type_constraints=True)
+    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
     trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))

fx/fx2trt.py

Lines changed: 11 additions & 8 deletions
@@ -13,7 +13,7 @@

 from .converter_registry import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
-from .utils import torch_dtype_to_trt, get_dynamic_dims, LowerPrecision
+from .utils import torch_dtype_to_trt, get_dynamic_dims

 TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")

@@ -146,24 +146,27 @@ def run(
         self,
         max_batch_size=64,
         max_workspace_size=1 << 25,
-        lower_precision=LowerPrecision.FP16,
+        fp16_mode=True,
+        int8_mode=False,
         sparse_weights=False,
         force_fp32_output=False,
         strict_type_constraints=False,
         algorithm_selector=None,
         timing_cache=None,
         profiling_verbosity=None,
     ) -> TRTInterpreterResult:
+        assert not (fp16_mode and int8_mode), "We cannot enable both fp16 and int8 mode."
+
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

-        # For float outputs, we set their dtype to fp16 only if LowerPrecision.FP16 and
+        # For float outputs, we set their dtype to fp16 only if fp16_mode=True and
         # force_fp32_output=False.
-        self.output_fp16 = not force_fp32_output and lower_precision == LowerPrecision.FP16
+        self.output_fp16 = not force_fp32_output and fp16_mode

-        if lower_precision == LowerPrecision.INT8 and not self.builder.platform_has_fast_int8:
+        if int8_mode and not self.builder.platform_has_fast_int8:
             raise RuntimeError("Current platform doesn't support fast native int8!")

-        if lower_precision == LowerPrecision.FP16 and not self.builder.platform_has_fast_fp16:
+        if fp16_mode and not self.builder.platform_has_fast_fp16:
             warnings.warn("Current platform doesn't support fast native fp16!")

         self.input_specs_iter = 0
@@ -185,10 +188,10 @@ def run(
         builder_config.profiling_verbosity = profiling_verbosity \
             if profiling_verbosity else \
             trt.ProfilingVerbosity.LAYER_NAMES_ONLY
-        if lower_precision == LowerPrecision.FP16:
+        if fp16_mode:
             builder_config.set_flag(trt.BuilderFlag.FP16)

-        if lower_precision == LowerPrecision.INT8:
+        if int8_mode:
             builder_config.set_flag(trt.BuilderFlag.INT8)

         if sparse_weights:
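In caller terms, the reverted run() contract is: the two flags are mutually exclusive, and each simply toggles the corresponding TensorRT builder flag. A quick illustration (interp is a TRTInterpreter; this is an editorial sketch, not repository code):

interp.run()                                 # default fp16_mode=True: FP16 engine
interp.run(fp16_mode=False)                  # pure FP32 build
interp.run(fp16_mode=False, int8_mode=True)  # INT8 build; platform must support fast int8
interp.run(fp16_mode=True, int8_mode=True)   # AssertionError: cannot enable both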

fx/lower.py

Lines changed: 14 additions & 8 deletions
@@ -33,7 +33,6 @@
 from .trt_module import (
     TRTModule,
 )
-from .utils import LowerPrecision


 logger = logging.getLogger(__name__)
@@ -80,7 +79,8 @@ def lower_to_trt(
     max_batch_size: int = 2048,
     max_workspace_size=1 << 25,
     explicit_batch_dimension=False,
-    lower_precision=LowerPrecision.FP16,
+    fp16_mode=True,
+    int8_mode=False,
     verbose_log=False,
     timing_cache_prefix="",
     save_timing_cache=False,
@@ -96,7 +96,8 @@ def lower_to_trt(
         max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
         max_workspace_size: Maximum size of workspace given to TensorRT.
         explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
-        lower_precision: lower precision config given to TRTModule. Can select between fp32, fp16 and int8.
+        fp16_mode: fp16 config given to TRTModule.
+        int8_mode: int8 config given to TRTModule.
         verbose_log: Enable verbose log for TensorRT if set True.
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
@@ -109,7 +110,8 @@ def lower_to_trt(
         max_batch_size=max_batch_size,
         max_workspace_size=max_workspace_size,
         explicit_batch_dimension=explicit_batch_dimension,
-        lower_precision=lower_precision,
+        fp16_mode=fp16_mode,
+        int8_mode=int8_mode,
         verbose_log=verbose_log,
         timing_cache_prefix=timing_cache_prefix,
         save_timing_cache=save_timing_cache,
@@ -135,7 +137,9 @@ class LowerSetting:

     explicit_precision: Use explicit precision during lowering.

-    lower_precision: lower_precision during lowering. Can select between fp32, fp16 and int8.
+    fp16_mode: Enable FP16 dtype during lowering.
+
+    int8_mode: Enable Int8 dtype during lowering.

     max_workspace_size: The maximum workspace size. The maximum GPU temporary
         memory which the TensorRT engine can use at execution time.
@@ -175,7 +179,8 @@ class LowerSetting:
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
     explicit_batch_dimension: bool = True
     explicit_precision: bool = False
-    lower_precision: LowerPrecision = LowerPrecision.FP32
+    fp16_mode: bool = False
+    int8_mode: bool = False
     max_workspace_size: int = 1 << 30
     strict_type_constraints: bool = False
     customized_fuse_pass: Sequence = ()
@@ -266,7 +271,8 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
         interp_result: TRTInterpreterResult = interpreter.run(
             max_batch_size=self.lower_setting.max_batch_size,
             max_workspace_size=self.lower_setting.max_workspace_size,
-            lower_precision=self.lower_setting.lower_precision,
+            fp16_mode=self.lower_setting.fp16_mode,
+            int8_mode=self.lower_setting.int8_mode,
             strict_type_constraints=self.lower_setting.strict_type_constraints,
             algorithm_selector=algo_selector,
             timing_cache=cache_data,
@@ -344,7 +350,7 @@ def __call__(
     ) -> nn.Module:
         module.eval()

-        if self.lower_setting.lower_precision == LowerPrecision.FP16:
+        if self.lower_setting.fp16_mode:
             module.half()
             inputs = tuple(x.half() if x.dtype == torch.float32 else x for x in inputs)
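For callers that configure lowering through LowerSetting rather than the lower_to_trt wrapper, the reverted dataclass exposes the same two booleans. A hedged configuration sketch (import path assumed from this file's location):

from fx2trt_oss.fx.lower import LowerSetting

# FP16 lowering; note the dataclass default is fp16_mode=False (FP32),
# while the lower_to_trt() convenience wrapper defaults to fp16_mode=True.
setting = LowerSetting(max_batch_size=256, fp16_mode=True, int8_mode=False)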

fx/utils.py

Lines changed: 0 additions & 7 deletions
@@ -5,13 +5,6 @@
 import torch

 from .types import Shape, TRTDataType
-from enum import Enum
-
-
-class LowerPrecision(Enum):
-    FP32 = "fp32"
-    FP16 = "fp16"
-    INT8 = "int8"


 def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
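With the enum removed, the three precision states it encoded correspond to flag combinations as follows (an editorial mapping for readers migrating call sites, not code in the repository):

# LowerPrecision.FP32 -> fp16_mode=False, int8_mode=False
# LowerPrecision.FP16 -> fp16_mode=True,  int8_mode=False
# LowerPrecision.INT8 -> fp16_mode=False, int8_mode=True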

test/quant/test_quant_trt.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
     TRTModule,
 )
 from fx2trt_oss.fx.lower import run_const_fold
-from fx2trt_oss.fx.utils import LowerPrecision
 from fx2trt_oss.tracer.acc_tracer import acc_ops
 from torch.ao.quantization import default_qconfig
 from torch.ao.quantization.quantize_fx import (
@@ -54,7 +53,7 @@ def lower_to_trt(model, inputs, shape_ranges):
         model,
         input_specs,
         explicit_batch_dimension=True, explicit_precision=True)
-    result = interp.run(lower_precision=LowerPrecision.INT8)
+    result = interp.run(fp16_mode=False, int8_mode=True)
     trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
     return trt_mod
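The test helper above takes explicit shape ranges for the dynamic batch dimension. A hypothetical invocation (the quantized model and shapes are illustrative, not from the test file):

import torch

shape_ranges = [((1, 3, 5, 5), (5, 3, 5, 5), (10, 3, 5, 5))]
trt_mod = lower_to_trt(quantized_model, [torch.randn(5, 3, 5, 5)], shape_ranges)
out = trt_mod(torch.randn(5, 3, 5, 5).cuda())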
