11
11
import re
12
12
import shutil
13
13
import subprocess
14
+
14
15
import tempfile
15
16
16
17
from dataclasses import dataclass
17
18
from typing import Callable , Dict , List , Optional , Sequence
18
19
19
-
20
20
# If this environment variable is set to true, save the flatc input files when
21
21
# serialization fails.
22
22
_SAVE_FLATC_ENV : str = "ET_EXIR_SAVE_FLATC_INPUTS_ON_FAILURE"
@@ -29,6 +29,14 @@ def _is_valid_alignment(alignment: int) -> bool:
29
29
return alignment > 0 and (alignment & (alignment - 1 )) == 0
30
30
31
31
32
+ # TODO(T182299196): Replace this hack with a proper flatc binary.
33
+ def _replace_infinity_in_json_file (content : str ) -> str :
34
+ content = re .sub (
35
+ r'"double_val"\s*:\s*(-)?Infinity' , r'"double_val": "\g<1>inf"' , content
36
+ )
37
+ return content
38
+
39
+
32
40
def _patch_schema_alignment (
33
41
schema : bytes ,
34
42
constant_tensor_alignment : Optional [int ],
@@ -281,8 +289,11 @@ def _program_json_to_flatbuffer(
281
289
json_path = os .path .join (temp_dir , file_stem + ".json" )
282
290
output_path = os .path .join (temp_dir , file_stem + ".pte" )
283
291
292
+ # TODO(T182299196): Replace this hack with a proper flatc binary.
293
+ replaced_program_json = _replace_infinity_in_json_file (program_json )
294
+
284
295
with open (json_path , "wb" ) as json_file :
285
- json_file .write (program_json .encode ("ascii" ))
296
+ json_file .write (replaced_program_json .encode ("ascii" ))
286
297
287
298
try :
288
299
_flatc_compile (temp_dir , schema_info .root_path , json_path )
0 commit comments