Skip to content

Commit 9896618

Browse files
tugsbayasgalan authored and facebook-github-bot committed
Make Resnet work with rexportable flow (#40)
Differential Revision: D47890878
fbshipit-source-id: c98f38ab90a6ff57c4dc047adcc3bdff044ae25d
1 parent a81505b commit 9896618

File tree

6 files changed

+18
-24
lines changed

6 files changed

+18
-24
lines changed

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from executorch.exir.passes.spec_prop_pass import SpecPropPass
3737
from executorch.exir.serialize import serialize_to_flatbuffer
38+
from executorch.exir.tracer import _default_decomposition_table
3839

3940
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
4041
from executorch.extension.pybindings.portable import ( # @manual
@@ -312,19 +313,22 @@ def quantize_and_test_model_with_quantizer(
312313
):
313314
module.eval()
314315
# program capture
315-
capture_config = exir.CaptureConfig(
316-
pt2_mode=True, enable_functionalization=True
316+
m = torch._export.capture_pre_autograd_graph(
317+
module, example_inputs, decomp_table=_default_decomposition_table()
317318
)
318-
captured_program = exir.capture(module, example_inputs, config=capture_config)
319-
m = captured_program.exported_program.graph_module
320319

321320
quantizer = XNNPACKQuantizer()
322321
quantization_config = get_symmetric_quantization_config()
323322
quantizer.set_global(quantization_config)
324323
prepared = prepare_pt2e(m, quantizer)
325324
converted = convert_pt2e(prepared)
326325

327-
captured_program.exported_program.graph_module = converted
326+
captured_program = exir.capture(
327+
converted,
328+
example_inputs,
329+
config=exir.CaptureConfig(enable_aot=True, _unlift=True),
330+
)
331+
328332
edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
329333
delegated_module = self.lower_module_and_test_output(
330334
module=edge_program,

exir/capture/TARGETS

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ python_library(
88
deps = [
99
":capture",
1010
":config",
11-
":unlift",
1211
],
1312
)
1413

@@ -19,7 +18,6 @@ python_library(
1918
],
2019
deps = [
2120
":config",
22-
":unlift",
2321
"//caffe2:torch",
2422
"//executorch/exir:error",
2523
"//executorch/exir:tracer",
@@ -40,13 +38,3 @@ python_library(
4038
"//executorch/exir/passes:lib",
4139
],
4240
)
43-
44-
python_library(
45-
name = "unlift",
46-
srcs = [
47-
"_unlift.py",
48-
],
49-
deps = [
50-
"//caffe2:torch",
51-
],
52-
)

exir/capture/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
EdgeCompileConfig,
1313
ExecutorchBackendConfig,
1414
)
15-
from executorch.exir.capture._unlift import unlift_exported_program_lifted_states
1615

1716
__all__ = [
1817
"capture",
1918
"capture_multiple",
2019
"CaptureConfig",
2120
"EdgeCompileConfig",
2221
"ExecutorchBackendConfig",
23-
"unlift_exported_program_lifted_states",
2422
]

exir/capture/_capture.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch
1414
import torch._export
1515
from executorch.exir.capture._config import CaptureConfig
16-
from executorch.exir.capture._unlift import unlift_exported_program_lifted_states
1716
from executorch.exir.error import ExportError, ExportErrorType, InternalError
1817
from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
1918
from executorch.exir.tracer import (
@@ -76,7 +75,7 @@ def capture(
7675
ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
7776
if not config._unlift:
7877
return ExirExportedProgram(ep, False)
79-
graph_module = unlift_exported_program_lifted_states(ep)
78+
graph_module = ep.module()
8079

8180
elif config.enable_dynamic_shape:
8281
graph_module, _ = dynamo_trace(

exir/delegate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,9 @@ def call_delegate_autograd(lowered_module, *args):
107107
def fake_requires_grad(var):
108108
if var is not None:
109109
var = var.detach()
110-
var.requires_grad = True
111-
return err_fn(var)
110+
if torch.is_floating_point(var) or torch.is_complex(var):
111+
var.requires_grad = True
112+
return var
112113

113114
return pytree.tree_map(fake_requires_grad, res)
114115

exir/pass_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,11 @@ def make_val(
134134
x = torch.dequantize(x)
135135

136136
try:
137-
fake_tensor = self.fake_tensor_mode.from_tensor(x)
137+
# TODO we should allocate static shapes
138+
# for param/buffer values
139+
fake_tensor = self.fake_tensor_mode.from_tensor(
140+
x, static_shapes=isinstance(x, torch.nn.Parameter)
141+
)
138142
except UnsupportedFakeTensorException:
139143
# TODO: This is just a workaround to get over the
140144
# x.as_subclass error

0 commit comments

Comments (0)