Skip to content

Commit 0f92498

Browse files
lucylqfacebook-github-bot
authored andcommitted
Update schema to include infinity for double values (#5203)
Summary: Pull Request resolved: #5203 Replace previous solution of parsing the json and replacing instances of "Infinity" with "inf" Differential Revision: D62393242
1 parent 7650667 commit 0f92498

File tree

5 files changed

+76
-20
lines changed

5 files changed

+76
-20
lines changed

exir/_serialize/_dataclass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class Example
9595
cls_flds = fields(cls)
9696
data = {}
9797
for field in cls_flds:
98+
print("field: ", field.name, "type: ", field.type, "default: ", field.default)
9899
key = field.name
99100
T = field.type
100101

@@ -129,6 +130,11 @@ class Example
129130
data[key] = [_json_to_dataclass(e, T) for e in value]
130131
continue
131132

133+
if get_origin(T) is Union:
134+
res = [x for x in get_args(get_type_hints(cls)[key]) if x == type(value)]
135+
data[key] = res[0](value)
136+
continue
137+
132138
# If T is an enum then lookup the value in the enum otherwise try to
133139
# cast value to whatever type is required
134140
if isinstance(T, enum.EnumMeta):

exir/_serialize/_flatbuffer.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@ def _is_valid_alignment(alignment: int) -> bool:
2929
return alignment > 0 and (alignment & (alignment - 1)) == 0
3030

3131

32-
# TODO(T182299196): Replace this hack with a proper flatc binary.
33-
def _replace_infinity_in_json_file(content: str) -> str:
34-
content = re.sub(
35-
r'"double_val"\s*:\s*(-)?Infinity', r'"double_val": "\g<1>inf"', content
36-
)
37-
return content
38-
39-
4032
def _patch_schema_alignment(
4133
schema: bytes,
4234
constant_tensor_alignment: Optional[int],
@@ -291,11 +283,8 @@ def _program_json_to_flatbuffer(
291283
json_path = os.path.join(temp_dir, file_stem + ".json")
292284
output_path = os.path.join(temp_dir, file_stem + ".pte")
293285

294-
# TODO(T182299196): Replace this hack with a proper flatc binary.
295-
replaced_program_json = _replace_infinity_in_json_file(program_json)
296-
297286
with open(json_path, "wb") as json_file:
298-
json_file.write(replaced_program_json.encode("ascii"))
287+
json_file.write(program_json.encode("ascii"))
299288

300289
try:
301290
_flatc_compile(temp_dir, schema_info.root_path, json_path)
@@ -330,6 +319,13 @@ def _program_json_to_flatbuffer(
330319
)
331320

332321

322+
def _replace_infinity_in_json_file(content: str) -> str:
323+
content = re.sub(
324+
r'"double_val"\s*:\s*(-)?inf', r'"double_val": "\g<1>inf"', content
325+
)
326+
return content
327+
328+
333329
def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
334330
"""Converts binary flatbuffer data into Program-compatible JSON.
335331
@@ -347,5 +343,7 @@ def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
347343
bin_file.write(program_flatbuffer)
348344

349345
_flatc_decompile(temp_dir, schema_info.root_path, bin_path)
350-
with open(json_path, "rb") as output_file:
351-
return output_file.read()
346+
with open(json_path, "r") as output_file:
347+
json_data = output_file.read()
348+
replaced_json_data = _replace_infinity_in_json_file(json_data)
349+
return bytes(replaced_json_data, "ascii")

exir/_serialize/_program.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212

1313
from dataclasses import dataclass
14-
from typing import ClassVar, List, Literal, Optional, Tuple
14+
from typing import ClassVar, List, Literal, Optional, Tuple, Union
1515

1616
from executorch.exir._serialize._cord import Cord
1717
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
@@ -47,7 +47,12 @@ def _program_to_json(program: Program) -> str:
4747
def _json_to_program(program_json: bytes) -> Program:
4848
"""Returns a Program deserialized from the given JSON string."""
4949
# construct program class recursively from dict
50-
return _json_to_dataclass(json.loads(program_json), cls=Program)
50+
print("program_json")
51+
print(program_json)
52+
x = json.loads(program_json)
53+
print("json.loads(program_json)")
54+
print(x)
55+
return _json_to_dataclass(x, cls=Program)
5156

5257

5358
def _padding_required(offset: int, alignment: int) -> int:
@@ -571,9 +576,8 @@ def deserialize_pte_binary(program_data: bytes) -> Program:
571576
segment_base_offset = eh.segment_base_offset
572577

573578
# Parse the flatbuffer data.
574-
program: Program = _json_to_program(
575-
_program_flatbuffer_to_json(program_data[:program_size])
576-
)
579+
x = _program_flatbuffer_to_json(program_data[:program_size])
580+
program: Program = _json_to_program(x)
577581

578582
if segment_base_offset != 0:
579583
# Move segment data back into the Program.

exir/emit/test/test_emit.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ExecutorchProgramManager,
2424
to_edge,
2525
)
26+
from executorch.exir._serialize._program import deserialize_pte_binary
2627
from executorch.exir.backend.backend_api import to_backend
2728
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2829
from executorch.exir.dialects._ops import ops as exir_ops
@@ -35,6 +36,7 @@
3536
from executorch.exir.schema import (
3637
Bool,
3738
DelegateCall,
39+
Double,
3840
EValue,
3941
ExecutionPlan,
4042
Int,
@@ -1620,3 +1622,33 @@ def forward(self, x):
16201622
executorch_module = _load_for_executorch_from_buffer(model.buffer)
16211623
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
16221624
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
1625+
1626+
def test_infinity_in_model(self) -> None:
1627+
class InfinityMaskModel(nn.Module):
1628+
def __init__(self):
1629+
super().__init__()
1630+
self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)
1631+
1632+
def forward(self, x):
1633+
masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
1634+
return masked_weights
1635+
1636+
model = to_edge(
1637+
export(
1638+
InfinityMaskModel(),
1639+
(torch.randn(2, 2),),
1640+
)
1641+
)
1642+
1643+
# Confirm that we can serialize the model with infinity in it.
1644+
model = model.to_executorch()
1645+
1646+
# Assert that the infinity is stored as a string "-inf".
1647+
values = model.executorch_program.execution_plan[0].values
1648+
self.assertEqual(values[5].val, Double(double_val=float("-inf")))
1649+
1650+
# Confirm that we can also deserialize the model with infinity in it.
1651+
pte_data = deserialize_pte_binary(model.buffer)
1652+
self.assertEqual(
1653+
pte_data.execution_plan, model.executorch_program.execution_plan
1654+
)

exir/schema.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,23 @@ class Bool:
7575

7676
@dataclass
7777
class Double:
78-
double_val: float
78+
double_val: Union[float, str]
79+
80+
def __init__(self, double_val: float) -> None:
81+
if double_val == float("inf"):
82+
self.double_val = "inf"
83+
elif double_val == float("-inf"):
84+
self.double_val = "-inf"
85+
else:
86+
self.double_val = double_val
87+
88+
def __post_init__(self) -> None:
89+
if isinstance(self.double_val, str):
90+
assert self.double_val in ["inf", "-inf"]
91+
else:
92+
assert isinstance(self.double_val, float)
93+
assert not self.double_val == float("inf")
94+
assert not self.double_val == float("-inf")
7995

8096

8197
@dataclass

0 commit comments

Comments
 (0)