Skip to content

Commit d724673

Browse files
andrewor14facebook-github-bot
authored andcommitted
Fix GPTQ import error after torchao refactor (#3760)
Summary: Pull Request resolved: #3760 Fix broken import after pytorch/ao#275 Reviewed By: jerryzh168 Differential Revision: D57888168 fbshipit-source-id: 51a63131ae14e362991ef962df325ec24f958e2d
1 parent 1ad4ae6 commit d724673

File tree

1 file changed

+8
-2
lines changed
  • examples/models/llama2/source_transformation

1 file changed

+8
-2
lines changed

examples/models/llama2/source_transformation/quantize.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,13 @@ def quantize(
9696
if calibration_tasks is None:
9797
calibration_tasks = ["wikitext"]
9898

99-
from torchao.quantization.GPTQ import InputRecorder
99+
try:
100+
# torchao 0.3+
101+
# pyre-ignore
102+
from torchao._eval import InputRecorder
103+
except ImportError:
104+
from torchao.quantization.GPTQ import InputRecorder
105+
100106
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
101107

102108
if tokenizer_path is None:
@@ -107,7 +113,7 @@ def quantize(
107113
)
108114

109115
inputs = (
110-
InputRecorder(
116+
InputRecorder( # pyre-ignore
111117
tokenizer,
112118
calibration_seq_length,
113119
None, # input_prep_func

0 commit comments

Comments
 (0)