Skip to content

Commit 5db96fb

Browse files
lucylqfacebook-github-bot
authored andcommitted
Update schema to include infinity for double values
Summary: Replace previous solution of parsing the json and replacing instances of "Infinity" with "inf" Differential Revision: D62393242
1 parent eca9ed5 commit 5db96fb

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

exir/_serialize/_flatbuffer.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@ def _is_valid_alignment(alignment: int) -> bool:
2929
return alignment > 0 and (alignment & (alignment - 1)) == 0
3030

3131

32-
# TODO(T182299196): Replace this hack with a proper flatc binary.
33-
def _replace_infinity_in_json_file(content: str) -> str:
34-
content = re.sub(
35-
r'"double_val"\s*:\s*(-)?Infinity', r'"double_val": "\g<1>inf"', content
36-
)
37-
return content
38-
39-
4032
def _patch_schema_alignment(
4133
schema: bytes,
4234
constant_tensor_alignment: Optional[int],
@@ -291,11 +283,8 @@ def _program_json_to_flatbuffer(
291283
json_path = os.path.join(temp_dir, file_stem + ".json")
292284
output_path = os.path.join(temp_dir, file_stem + ".pte")
293285

294-
# TODO(T182299196): Replace this hack with a proper flatc binary.
295-
replaced_program_json = _replace_infinity_in_json_file(program_json)
296-
297286
with open(json_path, "wb") as json_file:
298-
json_file.write(replaced_program_json.encode("ascii"))
287+
json_file.write(program_json.encode("ascii"))
299288

300289
try:
301290
_flatc_compile(temp_dir, schema_info.root_path, json_path)

exir/emit/test/test_emit.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from executorch.exir.schema import (
3636
Bool,
3737
DelegateCall,
38+
Double,
3839
EValue,
3940
ExecutionPlan,
4041
Int,
@@ -1620,3 +1621,27 @@ def forward(self, x):
16201621
executorch_module = _load_for_executorch_from_buffer(model.buffer)
16211622
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
16221623
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
1624+
1625+
def test_infinity_in_model(self) -> None:
1626+
class InfinityMaskModel(nn.Module):
1627+
def __init__(self):
1628+
super().__init__()
1629+
self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)
1630+
1631+
def forward(self, x):
1632+
masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
1633+
return masked_weights
1634+
1635+
model = to_edge(
1636+
export(
1637+
InfinityMaskModel(),
1638+
(torch.randn(2, 2),),
1639+
)
1640+
)
1641+
1642+
# Confirm that we can serialize the model with infinity in it.
1643+
model = model.to_executorch()
1644+
1645+
# Assert that the infinity is stored as a string "-inf".
1646+
values = model.executorch_program.execution_plan[0].values
1647+
self.assertEqual(values[5].val, Double(double_val=float("-inf")))

exir/schema.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,15 @@ class Bool:
7575

7676
@dataclass
7777
class Double:
78-
double_val: float
78+
double_val: Union[float, str]
79+
80+
def __init__(self, double_val: float) -> None:
81+
if double_val == float("inf"):
82+
self.double_val = "inf"
83+
elif double_val == float("-inf"):
84+
self.double_val = "-inf"
85+
else:
86+
self.double_val = double_val
7987

8088

8189
@dataclass

0 commit comments

Comments
 (0)