Skip to content

Commit 889e313

Browse files
authored
Merge pull request pytorch#24 from daniil-lyakhov/dl/ov/comments
[OpenVINO backend] Apply quantization comments
2 parents a28fcf3 + ee54b2f commit 889e313

File tree

2 files changed

+37
-67
lines changed

2 files changed

+37
-67
lines changed

backends/openvino/quantizer/quantizer.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import nncf.experimental.torch.fx as nncf_fx
1414

1515
import torch.fx
16+
1617
from nncf.common.graph.graph import NNCFGraph
1718
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
1819
from torch.ao.quantization.quantizer.quantizer import (
@@ -343,5 +344,39 @@ def validate(self, model: torch.fx.GraphModule) -> None:
343344
def transform_for_annotation(
    self,
    model: torch.fx.GraphModule,
) -> torch.fx.GraphModule:
    """
    Prepare the captured graph before it is annotated for quantization.

    Constant branches are folded (quantize/dequantize nodes excepted, per the
    helper's name) so that they are not quantized.

    :param model: The torch.fx.GraphModule that is about to be annotated.
    :return: The same ``model`` object that was passed in.
    """
    # The helper's return value is not used; the incoming object is handed
    # back, so the folding is presumably applied in place - confirm against NNCF.
    nncf_fx.transformations.fold_constant_except_qdq(model)
    return model
350+
351+
352+
def quantize_model(
    captured_model: torch.fx.GraphModule,
    calibration_dataset: torch.utils.data.DataLoader,
) -> torch.fx.GraphModule:
    """
    Quantizes a model with the NNCF ``quantize_pt2e`` flow and the OpenVINOQuantizer.

    :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
    :param calibration_dataset: A DataLoader containing calibration data for quantization.
        Element ``[0]`` of every item it yields is what gets passed on as calibration input.
    :return: The quantized model as a torch.fx.GraphModule.
    """
    quantizer = OpenVINOQuantizer()

    print("PTQ: Quantize the model")
    default_subset_size = 300

    # DataLoader.batch_size is None when automatic batching is disabled or a
    # custom batch_sampler is supplied. Fall back to one item per batch in that
    # case instead of failing with a TypeError in the division below.
    batch_size = calibration_dataset.batch_size
    if batch_size is None:
        batch_size = 1

    # Ceiling division: the number of batches needed to cover
    # default_subset_size samples.
    subset_size = (default_subset_size + batch_size - 1) // batch_size

    def transform(x):
        # Keep only the first element of each item the loader yields.
        return x[0]

    quantized_model = nncf_fx.quantize_pt2e(
        captured_model,
        quantizer,
        subset_size=subset_size,
        calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform),
        fold_quantize=False,
    )
    return quantized_model

examples/openvino/aot/aot_openvino_compiler.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,22 @@
88
import os
99
import shutil
1010
import subprocess
11-
from itertools import islice
1211
from pathlib import Path
1312

1413
import executorch
1514

16-
import nncf
15+
import nncf.torch
1716
import numpy as np
1817
import timm
1918
import torch
2019
import torchvision.models as torchvision_models
21-
from executorch.backends.openvino import OpenVINOQuantizer
2220
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
21+
from executorch.backends.openvino.quantizer.quantizer import quantize_model
2322
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
2423
from executorch.exir.backend.backend_details import CompileSpec
25-
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
2624
from sklearn.metrics import accuracy_score
2725
from timm.data import resolve_data_config
2826
from timm.data.transforms_factory import create_transform
29-
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3027
from torch.export import export
3128
from torch.export.exported_program import ExportedProgram
3229
from torchvision import datasets
@@ -129,55 +126,6 @@ def dump_inputs(calibration_dataset, dest_path):
129126
return input_files, targets
130127

131128

132-
def quantize_model(
    captured_model: torch.fx.GraphModule,
    calibration_dataset: torch.utils.data.DataLoader,
    use_nncf: bool,
) -> torch.fx.GraphModule:
    """
    Quantize a captured model through either the NNCF flow or the torch.ao PT2E flow.

    :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
    :param calibration_dataset: A DataLoader containing calibration data for quantization.
    :param use_nncf: Selects NNCF's ``quantize_pt2e`` when truthy; otherwise the
        ``prepare_pt2e`` / ``convert_pt2e`` pair is used.
    :return: The quantized model as a torch.fx.GraphModule.
    """
    ov_quantizer = OpenVINOQuantizer()

    print("PTQ: Quantize the model")
    target_samples = 300
    # Number of batches needed to cover target_samples, rounding up when the
    # batch size does not divide it evenly.
    full_batches, leftover = divmod(target_samples, calibration_dataset.batch_size)
    num_batches = full_batches + int(leftover > 0)

    def first_element(item):
        # Keep only the first element of each item the loader yields.
        return item[0]

    if not use_nncf:
        prepared_model = prepare_pt2e(captured_model, ov_quantizer)

        print("PTQ: Calibrate the model...")
        for item in islice(calibration_dataset, num_batches):
            prepared_model(first_element(item))

        print("PTQ: Convert the quantized model...")
        return convert_pt2e(prepared_model, fold_quantize=False)

    return quantize_pt2e(
        captured_model,
        ov_quantizer,
        subset_size=num_batches,
        calibration_dataset=nncf.Dataset(
            calibration_dataset, transform_func=first_element
        ),
        fold_quantize=False,
    )
179-
180-
181129
def validate_model(
182130
model_file_name: str, calibration_dataset: torch.utils.data.DataLoader
183131
) -> float:
@@ -231,7 +179,6 @@ def main(
231179
dataset_path: str,
232180
device: str,
233181
batch_size: int,
234-
quantization_flow: str,
235182
):
236183
"""
237184
Main function to load, quantize, and validate a model.
@@ -244,7 +191,6 @@ def main(
244191
:param dataset_path: Path to the dataset for calibration/validation.
245192
:param device: The device to run the model on (e.g., "cpu", "gpu").
246193
:param batch_size: Batch size for dataset loading.
247-
:param quantization_flow: The quantization method to use.
248194
"""
249195

250196
# Load the selected model
@@ -281,7 +227,6 @@ def main(
281227
quantized_model = quantize_model(
282228
aten_dialect.module(),
283229
calibration_dataset,
284-
use_nncf=quantization_flow == "nncf",
285230
)
286231

287232
aten_dialect: ExportedProgram = export(quantized_model, example_args)
@@ -360,15 +305,6 @@ def main(
360305
default="CPU",
361306
help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.",
362307
)
363-
parser.add_argument(
364-
"--quantization_flow",
365-
type=str,
366-
choices=["pt2e", "nncf"],
367-
default="nncf",
368-
help="Select the quantization flow (nncf or pt2e):"
369-
" pt2e is the default torch.ao quantization flow, while"
370-
" nncf is a custom method with additional algorithms to improve model performance.",
371-
)
372308

373309
args = parser.parse_args()
374310

@@ -384,5 +320,4 @@ def main(
384320
args.dataset,
385321
args.device,
386322
args.batch_size,
387-
args.quantization_flow,
388323
)

0 commit comments

Comments
 (0)