Skip to content

Commit 0d2cd71

Browse files
Mengchi Zhang, Wei Wei
authored and committed
[fx2trt] Enable int8 in lower_to_trt (#21)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/21
Reviewed By: jasonjk-park, yinghai
Differential Revision: D34916991
fbshipit-source-id: c088b4d6fe40444e13433a6eac76bcbd0fa078e6
1 parent 6a7f4db commit 0d2cd71

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

fx/fx2trt.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,8 @@ def run(
155155
timing_cache=None,
156156
profiling_verbosity=None,
157157
) -> TRTInterpreterResult:
158+
assert not (fp16_mode and int8_mode), "We cannot enable both fp16 and int8 mode."
159+
158160
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
159161

160162
# For float outputs, we set their dtype to fp16 only if fp16_mode=True and
@@ -193,7 +195,6 @@ def run(
193195
builder_config.set_flag(trt.BuilderFlag.INT8)
194196

195197
if sparse_weights:
196-
assert fp16_mode or int8_mode, "We can only enable sparsity in fp16 or int8 mode."
197198
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
198199

199200
if strict_type_constraints:

fx/lower.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ def lower_to_trt(
8080
max_workspace_size=1 << 25,
8181
explicit_batch_dimension=False,
8282
fp16_mode=True,
83+
int8_mode=False,
8384
verbose_log=False,
8485
timing_cache_prefix="",
8586
save_timing_cache=False,
@@ -96,6 +97,7 @@ def lower_to_trt(
9697
max_workspace_size: Maximum size of workspace given to TensorRT.
9798
explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
9899
fp16_mode: fp16 config given to TRTModule.
100+
int8_mode: int8 config given to TRTModule.
99101
verbose_log: Enable verbose log for TensorRT if set True.
100102
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
101103
save_timing_cache: Update timing cache with current timing cache data if set to True.
@@ -109,6 +111,7 @@ def lower_to_trt(
109111
max_workspace_size=max_workspace_size,
110112
explicit_batch_dimension=explicit_batch_dimension,
111113
fp16_mode=fp16_mode,
114+
int8_mode=int8_mode,
112115
verbose_log=verbose_log,
113116
timing_cache_prefix=timing_cache_prefix,
114117
save_timing_cache=save_timing_cache,

0 commit comments

Comments
 (0)