Skip to content

Commit c6b030e

Browse files
committed
Fix branch
1 parent 1548183 commit c6b030e

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

exir/program/_fake_program.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
import copy
810
from typing import Dict, Union
911

@@ -61,4 +63,5 @@ def update_to_real_program(
6163
"""Update the fake exported program to point to the real state dict. Modifies the
6264
fake exported program in-place.
6365
"""
64-
fake_exported_program._state_dict = real_exported_program.state_dict
66+
for k, v in real_exported_program.state_dict.items():
67+
fake_exported_program._state_dict[k] = v

exir/program/test/test_fake_program.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
78

89
import sys
910
import unittest
@@ -73,4 +74,6 @@ def test_fake_program(self) -> None:
7374

7475
update_to_real_program(fake_program, exported_program)
7576
self.assertEqual(exported_program.state_dict, fake_program.state_dict)
76-
self.assertEqual(id(exported_program.state_dict), id(fake_program.state_dict))
77+
self.assertEqual(
78+
exported_program.state_dict.keys(), fake_program.state_dict.keys()
79+
)

0 commit comments

Comments
 (0)