2 files changed: +5 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,8 @@ def run(
155
155
timing_cache = None ,
156
156
profiling_verbosity = None ,
157
157
) -> TRTInterpreterResult :
158
+ assert not (fp16_mode and int8_mode ), "We cannot enable both fp16 and int8 mode."
159
+
158
160
TRT_INTERPRETER_CALL_PRE_OBSERVER .observe (self .module )
159
161
160
162
# For float outputs, we set their dtype to fp16 only if fp16_mode=True and
@@ -193,7 +195,6 @@ def run(
193
195
builder_config .set_flag (trt .BuilderFlag .INT8 )
194
196
195
197
if sparse_weights :
196
- assert fp16_mode or int8_mode , "We can only enable sparsity in fp16 or int8 mode."
197
198
builder_config .set_flag (trt .BuilderFlag .SPARSE_WEIGHTS )
198
199
199
200
if strict_type_constraints :
Original file line number Diff line number Diff line change @@ -80,6 +80,7 @@ def lower_to_trt(
80
80
max_workspace_size = 1 << 25 ,
81
81
explicit_batch_dimension = False ,
82
82
fp16_mode = True ,
83
+ int8_mode = False ,
83
84
verbose_log = False ,
84
85
timing_cache_prefix = "" ,
85
86
save_timing_cache = False ,
@@ -96,6 +97,7 @@ def lower_to_trt(
96
97
max_workspace_size: Maximum size of workspace given to TensorRT.
97
98
explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension.
98
99
fp16_mode: fp16 config given to TRTModule.
100
+ int8_mode: int8 config given to TRTModule.
99
101
verbose_log: Enable verbose log for TensorRT if set True.
100
102
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
101
103
save_timing_cache: Update timing cache with current timing cache data if set to True.
@@ -109,6 +111,7 @@ def lower_to_trt(
109
111
max_workspace_size = max_workspace_size ,
110
112
explicit_batch_dimension = explicit_batch_dimension ,
111
113
fp16_mode = fp16_mode ,
114
+ int8_mode = int8_mode ,
112
115
verbose_log = verbose_log ,
113
116
timing_cache_prefix = timing_cache_prefix ,
114
117
save_timing_cache = save_timing_cache ,
You can’t perform that action at this time.
0 commit comments