Skip to content

Commit 6fab1e5

Browse files
ydwu4 authored and facebook-github-bot committed
Try remove no_grad in exir capture.
Summary: Remove the no_grad flag in capture to see whether it's still relevant.

Reviewed By: gmagogsfm, guangy10, angelayi

Differential Revision: D48079330

fbshipit-source-id: 38d9158d0a2f1244584a9b165f7020b755ee3d24
1 parent a81505b commit 6fab1e5

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

exir/capture/_capture.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545

4646

4747
@compatibility(is_backward_compatible=False)
48-
@torch.no_grad()
4948
def capture(
5049
f: Callable[..., Any],
5150
args: Tuple[Value, ...],

exir/tests/test_tracer.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ def __init__(self):
306306
super().__init__()
307307
self.register_buffer(
308308
"_bin_num_examples",
309-
torch.empty([42]).fill_(0.0),
309+
torch.empty([42]).fill_(
310+
0.0,
311+
),
310312
)
311313

312314
def forward(self, x, y, z):
@@ -327,8 +329,30 @@ def forward(self, x, y, z):
327329
torch.tensor(3.14),
328330
)
329331

332+
with self.assertRaisesRegex(
333+
RuntimeError,
334+
"Found a graph input that requires gradients, and received a mutation.",
335+
):
336+
_ = exir.capture(
337+
model,
338+
example_inputs,
339+
exir.CaptureConfig(
340+
pt2_mode=True,
341+
enable_aot=True,
342+
),
343+
)
344+
345+
# Note that model._bin_num_examples is mutated during exir.capture
346+
# We need to create a new_model
347+
new_model = Module()
348+
example_inputs = (
349+
torch.randn(4),
350+
torch.tensor(0),
351+
torch.tensor(3.14),
352+
)
353+
330354
ep = exir.capture(
331-
model,
355+
new_model,
332356
example_inputs,
333357
exir.CaptureConfig(
334358
pt2_mode=True,
@@ -342,7 +366,7 @@ def forward(self, x, y, z):
342366
torch.tensor(2.1),
343367
)
344368
graph_outputs = ep(*test_inputs)
345-
eager_outputs = model(*test_inputs)
369+
eager_outputs = new_model(*test_inputs)
346370
self.assertEqual(len(graph_outputs), 2)
347371
self.assertEqual(len(eager_outputs), 2)
348372
self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))

0 commit comments

Comments (0)