Skip to content

Commit 98b8ae1

Browse files
authored
Statically Quantize Image Encoder
Differential Revision: D61043280 Pull Request resolved: #4648
1 parent e2ca877 commit 98b8ae1

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

examples/models/llava/export_llava.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,21 +121,22 @@ def forward(self, images):
121121
llava_image_encode = LlavaImageEncoder(llava)
122122

123123
# quantizer
124-
linear_quantizer = XNNPACKQuantizer()
125-
operator_config_dynamic = get_symmetric_quantization_config(
126-
is_per_channel=True, is_dynamic=True
127-
)
128-
linear_quantizer.set_global(operator_config_dynamic)
124+
quantizer = XNNPACKQuantizer()
125+
quantizer.set_global(get_symmetric_quantization_config())
129126

130-
manager = LlavaEdgeManager(
131-
model=llava_image_encode,
132-
modelname="llava_image_encoder",
133-
max_seq_len=llava.text_model_args.max_seq_len, # This may not be right
134-
dtype=DType.fp32,
135-
use_kv_cache=True,
136-
example_inputs=(resized,),
137-
dynamic_shapes=dynamic_shapes,
138-
).capture_pre_autograd_graph()
127+
manager = (
128+
LlavaEdgeManager(
129+
model=llava_image_encode,
130+
modelname="llava_image_encoder",
131+
max_seq_len=llava.text_model_args.max_seq_len, # This may not be right
132+
dtype=DType.fp32,
133+
use_kv_cache=True,
134+
example_inputs=(resized,),
135+
dynamic_shapes=dynamic_shapes,
136+
)
137+
.capture_pre_autograd_graph()
138+
.pt2e_quantize([quantizer])
139+
)
139140

140141
# lower to executorch
141142
with torch.no_grad():
@@ -186,9 +187,11 @@ def main():
186187
llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
187188
llava = llava_model.get_eager_model()
188189

189-
prompt_before_image, resized, prompt_after_image = (
190-
llava_model.get_inputs_for_prefill()
191-
)
190+
(
191+
prompt_before_image,
192+
resized,
193+
prompt_after_image,
194+
) = llava_model.get_inputs_for_prefill()
192195

193196
image_encoder_ep = export_image_encoder(
194197
llava, resized, llava_model._get_image_dynamic_shapes()
@@ -211,9 +214,7 @@ def main():
211214
"text_model": text_model_ep,
212215
},
213216
partitioner={
214-
"image_encoder": [
215-
XnnpackPartitioner(config_precisions=ConfigPrecisionType.FP32)
216-
],
217+
"image_encoder": [XnnpackPartitioner()],
217218
"text_model": [
218219
XnnpackPartitioner(
219220
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,

0 commit comments

Comments
 (0)