Skip to content

Update schema to include infinity for double values #5203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions exir/_serialize/_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ class Example
data[key] = [_json_to_dataclass(e, T) for e in value]
continue

# If T is a Union, then check which type in the Union it is and initialize.
# eg. Double type in schema.py
if get_origin(T) is Union:
res = [x for x in get_args(get_type_hints(cls)[key]) if x == type(value)]
data[key] = res[0](value)
continue

# If T is an enum then lookup the value in the enum otherwise try to
# cast value to whatever type is required
if isinstance(T, enum.EnumMeta):
Expand Down
29 changes: 16 additions & 13 deletions exir/_serialize/_flatbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ def _is_valid_alignment(alignment: int) -> bool:
return alignment > 0 and (alignment & (alignment - 1)) == 0


# TODO(T182299196): Replace this hack with a proper flatc binary.
def _replace_infinity_in_json_file(content: str) -> str:
content = re.sub(
r'"double_val"\s*:\s*(-)?Infinity', r'"double_val": "\g<1>inf"', content
)
return content


def _patch_schema_alignment(
schema: bytes,
constant_tensor_alignment: Optional[int],
Expand Down Expand Up @@ -291,11 +283,8 @@ def _program_json_to_flatbuffer(
json_path = os.path.join(temp_dir, file_stem + ".json")
output_path = os.path.join(temp_dir, file_stem + ".pte")

# TODO(T182299196): Replace this hack with a proper flatc binary.
replaced_program_json = _replace_infinity_in_json_file(program_json)

with open(json_path, "wb") as json_file:
json_file.write(replaced_program_json.encode("ascii"))
json_file.write(program_json.encode("ascii"))

try:
_flatc_compile(temp_dir, schema_info.root_path, json_path)
Expand Down Expand Up @@ -330,6 +319,19 @@ def _program_json_to_flatbuffer(
)


def _replace_infinity_in_json_file(content: bytes) -> bytes:
"""Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs
is used to convert from flatbuffer to JSON. +-inf float values are not
supported by JSON, so we replace them with the string equivalent. When
converting from JSON to python dataclasses, the string is read as a Union
of float and string (see schema.py).
"""
content = re.sub(
rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content
)
return content


def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
"""Converts binary flatbuffer data into Program-compatible JSON.

Expand All @@ -348,4 +350,5 @@ def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:

_flatc_decompile(temp_dir, schema_info.root_path, bin_path)
with open(json_path, "rb") as output_file:
return output_file.read()
json_data = output_file.read()
return _replace_infinity_in_json_file(json_data)
32 changes: 32 additions & 0 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ExecutorchProgramManager,
to_edge,
)
from executorch.exir._serialize._program import deserialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -35,6 +36,7 @@
from executorch.exir.schema import (
Bool,
DelegateCall,
Double,
EValue,
ExecutionPlan,
Int,
Expand Down Expand Up @@ -1620,3 +1622,33 @@ def forward(self, x):
executorch_module = _load_for_executorch_from_buffer(model.buffer)
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

def test_infinity_in_model(self) -> None:
class InfinityMaskModel(nn.Module):
def __init__(self):
super().__init__()
self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)

def forward(self, x):
masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
return masked_weights

model = to_edge(
export(
InfinityMaskModel(),
(torch.randn(2, 2),),
)
)

# Confirm that we can serialize the model with infinity in it.
model = model.to_executorch()

# Assert that the infinity is stored as a string "-inf".
values = model.executorch_program.execution_plan[0].values
self.assertEqual(values[5].val, Double(double_val=float("-inf")))

# Confirm that we can also deserialize the model with infinity in it.
pte_data = deserialize_pte_binary(model.buffer)
self.assertEqual(
pte_data.execution_plan, model.executorch_program.execution_plan
)
18 changes: 17 additions & 1 deletion exir/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,23 @@ class Bool:

@dataclass
class Double:
double_val: float
double_val: Union[float, str]

def __init__(self, double_val: float) -> None:
if double_val == float("inf"):
self.double_val = "inf"
elif double_val == float("-inf"):
self.double_val = "-inf"
else:
self.double_val = double_val

def __post_init__(self) -> None:
if isinstance(self.double_val, str):
assert self.double_val in ["inf", "-inf"]
else:
assert isinstance(self.double_val, float)
assert not self.double_val == float("inf")
assert not self.double_val == float("-inf")


@dataclass
Expand Down
Loading