Skip to content

Commit c5a385e

Browse files
authored
Update schema to include infinity for double values
Differential Revision: D62393242 Pull Request resolved: #5203
1 parent f412630 commit c5a385e

File tree

4 files changed

+72
-14
lines changed

4 files changed

+72
-14
lines changed

exir/_serialize/_dataclass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ class Example
129129
data[key] = [_json_to_dataclass(e, T) for e in value]
130130
continue
131131

132+
# If T is a Union, then check which type in the Union it is and initialize.
133+
# eg. Double type in schema.py
134+
if get_origin(T) is Union:
135+
res = [x for x in get_args(get_type_hints(cls)[key]) if x == type(value)]
136+
data[key] = res[0](value)
137+
continue
138+
132139
# If T is an enum then lookup the value in the enum otherwise try to
133140
# cast value to whatever type is required
134141
if isinstance(T, enum.EnumMeta):

exir/_serialize/_flatbuffer.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@ def _is_valid_alignment(alignment: int) -> bool:
2929
return alignment > 0 and (alignment & (alignment - 1)) == 0
3030

3131

32-
# TODO(T182299196): Replace this hack with a proper flatc binary.
33-
def _replace_infinity_in_json_file(content: str) -> str:
34-
content = re.sub(
35-
r'"double_val"\s*:\s*(-)?Infinity', r'"double_val": "\g<1>inf"', content
36-
)
37-
return content
38-
39-
4032
def _patch_schema_alignment(
4133
schema: bytes,
4234
constant_tensor_alignment: Optional[int],
@@ -291,11 +283,8 @@ def _program_json_to_flatbuffer(
291283
json_path = os.path.join(temp_dir, file_stem + ".json")
292284
output_path = os.path.join(temp_dir, file_stem + ".pte")
293285

294-
# TODO(T182299196): Replace this hack with a proper flatc binary.
295-
replaced_program_json = _replace_infinity_in_json_file(program_json)
296-
297286
with open(json_path, "wb") as json_file:
298-
json_file.write(replaced_program_json.encode("ascii"))
287+
json_file.write(program_json.encode("ascii"))
299288

300289
try:
301290
_flatc_compile(temp_dir, schema_info.root_path, json_path)
@@ -330,6 +319,19 @@ def _program_json_to_flatbuffer(
330319
)
331320

332321

322+
def _replace_infinity_in_json_file(content: bytes) -> bytes:
323+
"""Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs
324+
is used to convert from flatbuffer to JSON. +-inf float values are not
325+
supported by JSON, so we replace them with the string equivalent. When
326+
converting from JSON to python dataclasses, the string is read as a Union
327+
of float and string (see schema.py).
328+
"""
329+
content = re.sub(
330+
rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content
331+
)
332+
return content
333+
334+
333335
def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
334336
"""Converts binary flatbuffer data into Program-compatible JSON.
335337
@@ -348,4 +350,5 @@ def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
348350

349351
_flatc_decompile(temp_dir, schema_info.root_path, bin_path)
350352
with open(json_path, "rb") as output_file:
351-
return output_file.read()
353+
json_data = output_file.read()
354+
return _replace_infinity_in_json_file(json_data)

exir/emit/test/test_emit.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ExecutorchProgramManager,
2424
to_edge,
2525
)
26+
from executorch.exir._serialize._program import deserialize_pte_binary
2627
from executorch.exir.backend.backend_api import to_backend
2728
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2829
from executorch.exir.dialects._ops import ops as exir_ops
@@ -35,6 +36,7 @@
3536
from executorch.exir.schema import (
3637
Bool,
3738
DelegateCall,
39+
Double,
3840
EValue,
3941
ExecutionPlan,
4042
Int,
@@ -1620,3 +1622,33 @@ def forward(self, x):
16201622
executorch_module = _load_for_executorch_from_buffer(model.buffer)
16211623
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
16221624
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)
1625+
1626+
def test_infinity_in_model(self) -> None:
1627+
class InfinityMaskModel(nn.Module):
1628+
def __init__(self):
1629+
super().__init__()
1630+
self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)
1631+
1632+
def forward(self, x):
1633+
masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
1634+
return masked_weights
1635+
1636+
model = to_edge(
1637+
export(
1638+
InfinityMaskModel(),
1639+
(torch.randn(2, 2),),
1640+
)
1641+
)
1642+
1643+
# Confirm that we can serialize the model with infinity in it.
1644+
model = model.to_executorch()
1645+
1646+
# Assert that the infinity is stored as a string "-inf".
1647+
values = model.executorch_program.execution_plan[0].values
1648+
self.assertEqual(values[5].val, Double(double_val=float("-inf")))
1649+
1650+
# Confirm that we can also deserialize the model with infinity in it.
1651+
pte_data = deserialize_pte_binary(model.buffer)
1652+
self.assertEqual(
1653+
pte_data.execution_plan, model.executorch_program.execution_plan
1654+
)

exir/schema.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,23 @@ class Bool:
7575

7676
@dataclass
7777
class Double:
78-
double_val: float
78+
double_val: Union[float, str]
79+
80+
def __init__(self, double_val: float) -> None:
81+
if double_val == float("inf"):
82+
self.double_val = "inf"
83+
elif double_val == float("-inf"):
84+
self.double_val = "-inf"
85+
else:
86+
self.double_val = double_val
87+
88+
def __post_init__(self) -> None:
89+
if isinstance(self.double_val, str):
90+
assert self.double_val in ["inf", "-inf"]
91+
else:
92+
assert isinstance(self.double_val, float)
93+
assert not self.double_val == float("inf")
94+
assert not self.double_val == float("-inf")
7995

8096

8197
@dataclass

0 commit comments

Comments
 (0)