Commit 12e12cc

Mengchi Zhang authored and Wei Wei committed
[WIP][fx2trt] Replacing fp16 and int8 mode with enum type (#74338)
Summary:
X-link: pytorch/pytorch#74338

Pull Request resolved: https://github.com/pytorch/fx2trt/pull/24

X-link: pytorch/benchmark#805

Reviewed By: jasonjk-park

Differential Revision: D34929680

fbshipit-source-id: 8d693ffdfc28d12f5b88aba170da87b192c5e5a2
1 parent 4addf3d commit 12e12cc
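
The change in a nutshell: the mutually exclusive fp16_mode/int8_mode booleans are replaced by a single LowerPrecision enum. Below is a minimal sketch (not part of the diff) of how a call site changes; the module and input are toy placeholders made up for illustration:

import torch
from fx2trt_oss.fx import lower_to_trt
from fx2trt_oss.fx.utils import LowerPrecision

# Toy module and input, assumed for illustration only.
module = torch.nn.Linear(8, 8).cuda().eval()
inputs = [torch.randn(4, 8).cuda()]

# Before this commit (two boolean flags, at most one allowed to be True):
#   lowered = lower_to_trt(module, inputs, fp16_mode=True, int8_mode=False)

# After this commit (one enum value):
lowered = lower_to_trt(module, inputs, lower_precision=LowerPrecision.FP16)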

File tree

6 files changed: +31 −30 lines changed

fx/example/lower_example.py
fx/example/quantized_resnet_test.py
fx/fx2trt.py
fx/lower.py
fx/utils.py
test/quant/test_quant_trt.py

fx/example/lower_example.py

Lines changed: 2 additions & 1 deletion

@@ -5,6 +5,7 @@
 import torch
 import torchvision
 from fx2trt_oss.fx import lower_to_trt
+from fx2trt_oss.fx.utils import LowerPrecision


 """
@@ -167,7 +168,7 @@ def run_configuration_benchmark(
         time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
     elif not conf.jit:
         # Run lowering eager mode benchmark
-        lowered_module = lower_to_trt(module, input, max_batch_size=conf.batch_size, fp16_mode=conf.fp16)
+        lowered_module = lower_to_trt(module, input, max_batch_size=conf.batch_size, lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32)
         time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
     else:
         print("Lowering with JIT is not available!", "red")

fx/example/quantized_resnet_test.py

Lines changed: 4 additions & 3 deletions

@@ -1,6 +1,7 @@
 import torch.fx
 import torchvision.models as models
 from fx2trt_oss.fx import TRTInterpreter, InputTensorSpec, TRTModule
+from fx2trt_oss.fx.utils import LowerPrecision
 from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
 import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
 import copy
@@ -16,7 +17,7 @@ def build_fp16_trt(rn18):
     rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
     interp = TRTInterpreter(
         rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
-    interpreter_result = interp.run(fp16_mode=True)
+    interpreter_result = interp.run(lower_precision=LowerPrecision.FP16)
     return TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)

 @torch.no_grad()
@@ -47,7 +48,7 @@ def build_int8_trt(rn18):
         [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
             shape_ranges=[((1, 3, 224, 224), (5, 3, 224, 224), (10, 3, 224, 224))], has_batch_dim=True)],
         explicit_batch_dimension=True, explicit_precision=True, logger_level=trt.Logger.VERBOSE)
-    interpreter_result = interp.run(fp16_mode=False, int8_mode=True)
+    interpreter_result = interp.run(lower_precision=LowerPrecision.INT8)
     trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("explicit quant result diff max", torch.max(ref_res - trt_res.cpu()))
@@ -75,7 +76,7 @@ def build_int8_trt_implicit_quant(rn18):
     shape_prop.ShapeProp(traced_rn18).propagate(data)
     traced_rn18 = NormalizeArgs(traced_rn18).transform()
     interp = TRTInterpreter(traced_rn18, InputTensorSpec.from_tensors([data]), logger_level=trt.Logger.VERBOSE)
-    interpreter_result = interp.run(fp16_mode=False, int8_mode=True, strict_type_constraints=True)
+    interpreter_result = interp.run(lower_precision=LowerPrecision.INT8, strict_type_constraints=True)
     trt_mod = TRTModule(interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names)
     trt_res = trt_mod(data.cuda())
     print("implicit quant result diff max", torch.max(ref_res - trt_res.cpu()))

fx/fx2trt.py

Lines changed: 8 additions & 11 deletions

@@ -13,7 +13,7 @@

 from .converter_registry import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
-from .utils import torch_dtype_to_trt, get_dynamic_dims
+from .utils import torch_dtype_to_trt, get_dynamic_dims, LowerPrecision

 TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")

@@ -146,27 +146,24 @@ def run(
         self,
         max_batch_size=64,
         max_workspace_size=1 << 25,
-        fp16_mode=True,
-        int8_mode=False,
+        lower_precision=LowerPrecision.FP16,
         sparse_weights=False,
         force_fp32_output=False,
         strict_type_constraints=False,
         algorithm_selector=None,
         timing_cache=None,
         profiling_verbosity=None,
     ) -> TRTInterpreterResult:
-        assert not (fp16_mode and int8_mode), "We cannot enable both fp16 and int8 mode."
-
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

-        # For float outputs, we set their dtype to fp16 only if fp16_mode=True and
+        # For float outputs, we set their dtype to fp16 only if LowerPrecision.FP16 and
         # force_fp32_output=False.
-        self.output_fp16 = not force_fp32_output and fp16_mode
+        self.output_fp16 = not force_fp32_output and lower_precision == LowerPrecision.FP16

-        if int8_mode and not self.builder.platform_has_fast_int8:
+        if lower_precision == LowerPrecision.INT8 and not self.builder.platform_has_fast_int8:
             raise RuntimeError("Current platform doesn't support fast native int8!")

-        if fp16_mode and not self.builder.platform_has_fast_fp16:
+        if lower_precision == LowerPrecision.FP16 and not self.builder.platform_has_fast_fp16:
             warnings.warn("Current platform doesn't support fast native fp16!")

         self.input_specs_iter = 0
@@ -188,10 +185,10 @@ def run(
         builder_config.profiling_verbosity = profiling_verbosity \
             if profiling_verbosity else \
             trt.ProfilingVerbosity.LAYER_NAMES_ONLY
-        if fp16_mode:
+        if lower_precision == LowerPrecision.FP16:
             builder_config.set_flag(trt.BuilderFlag.FP16)

-        if int8_mode:
+        if lower_precision == LowerPrecision.INT8:
             builder_config.set_flag(trt.BuilderFlag.INT8)

         if sparse_weights:
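
For context, a self-contained sketch (not part of the diff) of the reasoning behind this signature change: with two booleans the invalid fp16+int8 combination was representable and had to be rejected at runtime, whereas an enum value can only name one precision, so run() drops its assert and tests equality instead:

from enum import Enum

class LowerPrecision(Enum):  # mirrors the definition added in fx/utils.py
    FP32 = "fp32"
    FP16 = "fp16"
    INT8 = "int8"

# Old API: (fp16_mode=True, int8_mode=True) was representable and had to be
# caught by an assert at the top of run().
# New API: exactly one precision is representable, so the checks reduce to
# simple equality tests, as in the hunks above.
lower_precision = LowerPrecision.INT8
if lower_precision == LowerPrecision.INT8:
    print("would set trt.BuilderFlag.INT8")
elif lower_precision == LowerPrecision.FP16:
    print("would set trt.BuilderFlag.FP16")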

fx/lower.py

Lines changed: 8 additions & 14 deletions

@@ -33,6 +33,7 @@
 from .trt_module import (
     TRTModule,
 )
+from .utils import LowerPrecision


 logger = logging.getLogger(__name__)
@@ -79,8 +80,7 @@ def lower_to_trt(
     max_batch_size: int = 2048,
     max_workspace_size=1 << 25,
     explicit_batch_dimension=False,
-    fp16_mode=True,
-    int8_mode=False,
+    lower_precision=LowerPrecision.FP16,
     verbose_log=False,
     timing_cache_prefix="",
     save_timing_cache=False,
@@ -96,8 +96,7 @@ def lower_to_trt(
         max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set)
         max_workspace_size: Maximum size of workspace given to TensorRT.
         explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
-        fp16_mode: fp16 config given to TRTModule.
-        int8_mode: int8 config given to TRTModule.
+        lower_precision: Lower precision config given to TRTModule. Can select between fp32, fp16 and int8.
         verbose_log: Enable verbose log for TensorRT if set True.
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
@@ -110,8 +109,7 @@ def lower_to_trt(
        max_batch_size=max_batch_size,
        max_workspace_size=max_workspace_size,
        explicit_batch_dimension=explicit_batch_dimension,
-       fp16_mode=fp16_mode,
-       int8_mode=int8_mode,
+       lower_precision=lower_precision,
        verbose_log=verbose_log,
        timing_cache_prefix=timing_cache_prefix,
        save_timing_cache=save_timing_cache,
@@ -137,9 +135,7 @@ class LowerSetting:

         explicit_precision: Use explicit precision during lowering.

-        fp16_mode: Enable FP16 dtype during lowering.
-
-        int8_mode: Enable Int8 dtype during lowering.
+        lower_precision: Precision to use during lowering. Can select between fp32, fp16 and int8.

         max_workspace_size: The maximum workspace size. The maximum GPU temporary
         memory which the TensorRT engine can use at execution time.
@@ -179,8 +175,7 @@ class LowerSetting:
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
     explicit_batch_dimension: bool = True
     explicit_precision: bool = False
-    fp16_mode: bool = False
-    int8_mode: bool = False
+    lower_precision: LowerPrecision = LowerPrecision.FP32
     max_workspace_size: int = 1 << 30
     strict_type_constraints: bool = False
     customized_fuse_pass: Sequence = ()
@@ -271,8 +266,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
         interp_result: TRTInterpreterResult = interpreter.run(
             max_batch_size=self.lower_setting.max_batch_size,
             max_workspace_size=self.lower_setting.max_workspace_size,
-            fp16_mode=self.lower_setting.fp16_mode,
-            int8_mode=self.lower_setting.int8_mode,
+            lower_precision=self.lower_setting.lower_precision,
             strict_type_constraints=self.lower_setting.strict_type_constraints,
             algorithm_selector=algo_selector,
             timing_cache=cache_data,
@@ -350,7 +344,7 @@ def __call__(
     ) -> nn.Module:
         module.eval()

-        if self.lower_setting.fp16_mode:
+        if self.lower_setting.lower_precision == LowerPrecision.FP16:
             module.half()
             inputs = tuple(x.half() if x.dtype == torch.float32 else x for x in inputs)
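
A hedged usage sketch for the new LowerSetting field (not part of the diff); this assumes the remaining dataclass fields keep the defaults shown above:

from fx2trt_oss.fx.lower import LowerSetting
from fx2trt_oss.fx.utils import LowerPrecision

# The default moves from fp16_mode=False/int8_mode=False to
# LowerPrecision.FP32, so fp32 remains the default lowering behavior.
setting = LowerSetting(lower_precision=LowerPrecision.FP16)
assert setting.lower_precision == LowerPrecision.FP16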

fx/utils.py

Lines changed: 7 additions & 0 deletions

@@ -5,6 +5,13 @@
 import torch

 from .types import Shape, TRTDataType
+from enum import Enum
+
+
+class LowerPrecision(Enum):
+    FP32 = "fp32"
+    FP16 = "fp16"
+    INT8 = "int8"


 def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
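
Because the members carry string values, standard Enum lookup by value works; a small self-contained sketch mirroring the definition added above:

from enum import Enum

class LowerPrecision(Enum):  # mirrors fx/utils.py
    FP32 = "fp32"
    FP16 = "fp16"
    INT8 = "int8"

# Members round-trip through their string values, which is convenient when
# the precision comes from a config file or a CLI flag.
assert LowerPrecision("fp16") is LowerPrecision.FP16
assert LowerPrecision.INT8.value == "int8"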

test/quant/test_quant_trt.py

Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
     TRTModule,
 )
 from fx2trt_oss.fx.lower import run_const_fold
+from fx2trt_oss.fx.utils import LowerPrecision
 from fx2trt_oss.tracer.acc_tracer import acc_ops
 from torch.ao.quantization import default_qconfig
 from torch.ao.quantization._quantize_fx_do_not_use import (
@@ -53,7 +54,7 @@ def lower_to_trt(model, inputs, shape_ranges):
         model,
         input_specs,
         explicit_batch_dimension=True, explicit_precision=True)
-    result = interp.run(fp16_mode=False, int8_mode=True)
+    result = interp.run(lower_precision=LowerPrecision.INT8)
     trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
     return trt_mod
