5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import logging
8
- from typing import Any , Callable , Tuple
8
+ from typing import Any , Tuple
9
9
10
10
import torch
11
11
16
16
17
17
18
18
def export_program (
19
- model : Callable ,
19
+ model : torch . nn . Module ,
20
20
inputs : Any ,
21
- pt2_quant : bool = False ,
22
21
) -> ExportedProgram :
23
- # we don't support training mode. Make it eval
24
- if hasattr (model , "eval" ):
25
- if pt2_quant :
26
- # pyre-fixme[6]: Incompatible parameter type.
27
- torch .ao .quantization .move_exported_model_to_eval (model )
28
- else :
29
- # pyre-fixme[16]: Anonymous callable has no attribute `eval`.
30
- model .eval ()
31
-
32
- # if it's already an ExportedProgram, just return it
33
- if isinstance (model , ExportedProgram ):
34
- return model
35
-
36
22
assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
37
23
24
+ # If the model is already a GraphModule (most likely from quantization), call the
25
+ # suggested torch.ao.quantization API instead, which only does dropout and batchnorm.
26
+ if isinstance (model , torch .fx .GraphModule ):
27
+ torch .ao .quantization .move_exported_model_to_eval (model )
28
+ else :
29
+ # We don't support training mode. Make it eval
30
+ if hasattr (model , "eval" ):
31
+ model .eval ()
32
+
38
33
# Prevent mkldnn decompositions
39
34
torch ._C ._set_mkldnn_enabled (False )
40
35
@@ -44,13 +39,12 @@ def export_program(
44
39
45
40
# Export the model and lower it it edge IR.
46
41
def export_to_edge (
47
- model : Callable ,
42
+ model : torch . nn . Module ,
48
43
inputs : Any ,
49
- pt2_quant : bool = False ,
50
44
dump_graphs : bool = False ,
51
45
) -> Tuple [EdgeProgramManager , ExportedProgram ]:
52
46
# Export the model into an ExportedProgram.
53
- expo_program = export_program (model , inputs , pt2_quant )
47
+ expo_program = export_program (model , inputs )
54
48
55
49
if dump_graphs :
56
50
logging .info (f"Exported graph:\n { expo_program .graph_module .graph } " )
0 commit comments