
Commit b257953

tugsbayasgalan authored and facebook-github-bot committed
Make MilanDictation and Resnet work with re-exportable flow
Summary:
X-link: pytorch/pytorch#106676

MilanDictation and Resnet use the "re-export" flow now. Also did some refactoring to make the code a little cleaner.

Differential Revision: D47890878

fbshipit-source-id: 8e346e7d4070a968b5e608f4268973114ccac560
1 parent 416d114 · commit b257953
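For orientation, here is a minimal sketch of the "re-export" flow the commit adopts, pieced together from the diffs below: export once with torch._export.export, transform the unlifted module, then capture the result again through exir. The toy Add module and the bare to_edge() call are illustrative assumptions, not code from this commit.

```python
import torch
import executorch.exir as exir

class Add(torch.nn.Module):  # hypothetical stand-in for MilanDictation/Resnet
    def forward(self, x, y):
        return x + y

example_inputs = (torch.randn(2), torch.randn(2))
ep = torch._export.export(Add(), example_inputs)  # first export
m = ep.module()                                   # unlifted, stateful nn.Module
# ... transform m here (the test below inserts quantization) ...
prog = exir.capture(                              # re-export the transformed module
    m, example_inputs, config=exir.CaptureConfig(enable_aot=True, _unlift=True)
)
edge = prog.to_edge()                             # continue down the EXIR pipeline
```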

File tree

6 files changed: +17 −25 lines changed

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 8 additions & 6 deletions

@@ -312,19 +312,21 @@ def quantize_and_test_model_with_quantizer(
     ):
         module.eval()
         # program capture
-        capture_config = exir.CaptureConfig(
-            pt2_mode=True, enable_functionalization=True
-        )
-        captured_program = exir.capture(module, example_inputs, config=capture_config)
-        m = captured_program.exported_program.graph_module
+        captured_program = torch._export.export(module, example_inputs)
+        m = captured_program.module()

         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config()
         quantizer.set_global(quantization_config)
         prepared = prepare_pt2e(m, quantizer)
         converted = convert_pt2e(prepared)

-        captured_program.exported_program.graph_module = converted
+        capture_config = exir.CaptureConfig(enable_aot=True, _unlift=True)
+
+        captured_program = exir.capture(
+            converted, example_inputs, config=capture_config
+        )
+
         edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
         delegated_module = self.lower_module_and_test_output(
             module=edge_program,
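For readers following the hunk above, a self-contained sketch of the PT2E quantize-then-re-export round trip the test now performs; the toy Linear module and the single calibration call are assumptions (the hunk does not show calibration), and the import paths assume the torch.ao.quantization PT2E API of this era.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class Linear(torch.nn.Module):  # hypothetical stand-in for the test's module
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

example_inputs = (torch.randn(1, 4),)
m = torch._export.export(Linear().eval(), example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(m, quantizer)  # insert observers into the graph
prepared(*example_inputs)              # calibrate on sample data (assumed step)
converted = convert_pt2e(prepared)     # swap observers for quantize/dequantize ops
```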

exir/capture/TARGETS

Lines changed: 0 additions & 12 deletions

@@ -8,7 +8,6 @@ python_library(
     deps = [
         ":capture",
         ":config",
-        ":unlift",
     ],
 )

@@ -19,7 +18,6 @@ python_library(
     ],
     deps = [
         ":config",
-        ":unlift",
         "//caffe2:torch",
         "//executorch/exir:error",
         "//executorch/exir:tracer",
@@ -40,13 +38,3 @@ python_library(
         "//executorch/exir/passes:lib",
     ],
 )
-
-python_library(
-    name = "unlift",
-    srcs = [
-        "_unlift.py",
-    ],
-    deps = [
-        "//caffe2:torch",
-    ],
-)

exir/capture/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -12,13 +12,11 @@
     EdgeCompileConfig,
     ExecutorchBackendConfig,
 )
-from executorch.exir.capture._unlift import unlift_exported_program_lifted_states

 __all__ = [
     "capture",
     "capture_multiple",
     "CaptureConfig",
     "EdgeCompileConfig",
     "ExecutorchBackendConfig",
-    "unlift_exported_program_lifted_states",
 ]

exir/capture/_capture.py

Lines changed: 1 addition & 2 deletions

@@ -13,7 +13,6 @@
 import torch
 import torch._export
 from executorch.exir.capture._config import CaptureConfig
-from executorch.exir.capture._unlift import unlift_exported_program_lifted_states
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
 from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
 from executorch.exir.tracer import (
@@ -76,7 +75,7 @@ def capture(
         ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
         if not config._unlift:
             return ExirExportedProgram(ep, False)
-        graph_module = unlift_exported_program_lifted_states(ep)
+        graph_module = ep.module()

     elif config.enable_dynamic_shape:
         graph_module, _ = dynamo_trace(
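The deleted local helper unlift_exported_program_lifted_states is replaced by the upstream ExportedProgram.module(), which re-attaches lifted parameters and buffers as module attributes so the graph runs as an ordinary nn.Module. A minimal sketch, assuming the torch._export API this commit targets:

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(3))

    def forward(self, x):
        return x + self.w

ep = torch._export.export(M(), (torch.randn(3),))
unlifted = ep.module()  # self.w restored as an attribute, not a lifted graph input
print(unlifted(torch.randn(3)))
```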

exir/delegate.py

Lines changed: 3 additions & 2 deletions

@@ -107,8 +107,9 @@ def call_delegate_autograd(lowered_module, *args):
     def fake_requires_grad(var):
         if var is not None:
             var = var.detach()
-            var.requires_grad = True
-        return err_fn(var)
+            if torch.is_floating_point(var) or torch.is_complex(var):
+                var.requires_grad = True
+        return var

     return pytree.tree_map(fake_requires_grad, res)
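The new dtype guard matters because requires_grad can only be set on floating-point or complex tensors; flipping it on an integer tensor raises a RuntimeError, which is presumably what this hunk fixes. A small illustration:

```python
import torch

t = torch.tensor([1, 2, 3])  # int64 tensor
try:
    t.requires_grad = True
except RuntimeError as e:
    # "only Tensors of floating point and complex dtype can require gradients"
    print(e)

if torch.is_floating_point(t) or torch.is_complex(t):
    t.requires_grad = True  # skipped for integer dtypes, mirroring the fix above
```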

exir/pass_base.py

Lines changed: 5 additions & 1 deletion

@@ -134,7 +134,11 @@ def make_val(
                 x = torch.dequantize(x)

             try:
-                fake_tensor = self.fake_tensor_mode.from_tensor(x)
+                # TODO we should allocate static shapes
+                # for param/buffer values
+                fake_tensor = self.fake_tensor_mode.from_tensor(
+                    x, static_shapes=isinstance(x, torch.nn.Parameter)
+                )
             except UnsupportedFakeTensorException:
                 # TODO: This is just a workaround to get over the
                 # x.as_subclass error
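For background on the new flag: from_tensor(..., static_shapes=True) fakeifies a tensor with its concrete sizes instead of allocating symbolic ones, which is what the TODO asks for on parameters and buffers. A hedged sketch against the private torch._subclasses API of this era:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())  # a ShapeEnv enables symbolic shapes
p = torch.nn.Parameter(torch.randn(4, 4))
fake = mode.from_tensor(p, static_shapes=isinstance(p, torch.nn.Parameter))
print(fake.shape)  # torch.Size([4, 4]): concrete sizes, no symbols allocated
```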
