Skip to content

Commit 41c0d5b

Browse files
angelayifacebook-github-bot
authored andcommitted
Fix handling of nonpersistent buffers
Differential Revision: D55268323
1 parent 8532e79 commit 41c0d5b

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

exir/lowered_backend_module.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,22 @@ def _get_new_signature( # noqa: C901
483483
elif is_tagged:
484484
input_specs.append(orig_input_spec)
485485

486-
if orig_input_spec.kind in (InputKind.PARAMETER, InputKind.BUFFER):
486+
if orig_input_spec.kind == InputKind.PARAMETER:
487487
new_state_dict[orig_input_spec.target] = (
488488
original_program.state_dict[orig_input_spec.target]
489489
)
490+
elif (
491+
orig_input_spec.kind == InputKind.BUFFER
492+
and orig_input_spec.persistent
493+
):
494+
new_state_dict[orig_input_spec.target] = (
495+
original_program.state_dict[orig_input_spec.target]
496+
)
497+
elif orig_input_spec.kind == InputKind.BUFFER:
498+
assert not orig_input_spec.persistent
499+
new_constants[orig_input_spec.target] = original_program.constants[
500+
orig_input_spec.target
501+
]
490502
elif orig_input_spec.kind in (
491503
InputKind.CONSTANT_TENSOR,
492504
InputKind.CUSTOM_OBJ,

0 commit comments

Comments
 (0)