Skip to content

Commit c731673

Browse files
DenisVieriu97facebook-github-bot
authored andcommitted
Fix static llama AOT tracing (#2800)
Summary: Fix static llama tracing for MPS backend: - **AOT** tracing: ``` python3 -m examples.apple.mps.scripts.mps_example --model_name "llama2" --use_partitioner --no-use_fp16 --checkpoint stories110M.pt --params params.json ``` **Testing:** AOT: ``` python3 -m examples.apple.mps.scripts.mps_example --model_name "llama2" --use_partitioner --no-use_fp16 --checkpoint stories110M.pt --params params.json -b ``` Runtime: ``` ./cmake-out/examples/apple/mps/mps_executor_runner --model_path ../../pytorch-executorch/executorch/llama2_mps_bundled_fp32.pte --bundled_program --num_runs 0 Output 0: tensor(sizes=[1, 1, 512], [ 0.271252, 0.547134, -0.319407, -0.519001, -0.817092, -0.511795, 1.19794, 0.169265, 0.74455, -0.59659, -0.754759, -0.688871, 0.298589, -0.443414, -0.086443, -0.565953, 0.864259, 0.115826, 0.85074, -0.525906, 0.125811, -0.00499783, -0.463692, 0.29455, -0.910827, 0.539898, -0.958917, -0.160505, 0.345872, -0.33394, -0.481556, 1.23247, 0.972564, 1.2528, 0.74983, 0.117909, -0.537549, -0.11374, 0.0533589, -0.724433, 0.267418, 0.383111, -0.589426, -0.273412, 0.536997, -0.296996, -0.037801, -0.467146, -0.111444, 0.2018, 0.138447, -0.168977, 0.0633859, -0.147186, 0.162552, -0.340015, -0.481295, -0.285551, -0.486715, 0.266989, 0.178192, -0.15646, -0.538977, -0.161788, -0.844837, -0.0201809, -0.0330472, -1.00392, -0.683174, -0.352963, 0.0247648, 0.100644, 0.0750983, 0.617042, -0.0325033, 0.80563, 0.722587, -0.16737, -0.448206, -0.726605, 0.0371236, -0.340444, 0.0844629, -0.744782, 0.713093, -1.23478, -0.624044, -0.927914, 0.226109, 0.871968, -0.181462, 0.737743, 0.630924, 0.352844, 1.22101, -0.433692, -0.545998, 1.25375, 0.604229, 0.567072, ..., -0.664146, -0.325717, 0.0465028, -0.0638815, -0.414189, -0.0774017, 0.664497, -0.80469, 0.665282, 0.142916, -0.136235, -0.181776, -0.0936792, 0.149341, 0.102156, 0.259336, 0.881158, 0.785714, -0.345178, -0.165404, 0.100859, -0.43653, 0.239712, -0.111407, 0.744646, 0.651979, -0.15812, 0.0528999, 0.699308, 0.0331134, -0.102761, -0.0878261, -0.67926, -0.211967, 0.852889, -0.395876, 0.0924324, -0.262803, -0.622059, 0.550099, 0.162211, -0.531762, 0.0518005, -0.920345, -0.014961, 0.132838, -0.0850867, 0.49853, -1.38389, 1.04421, 0.220865, 0.160585, -0.239085, -0.745911, 1.19387, -0.597559, -0.722064, 0.267607, 0.65336, -0.360622, -0.20821, -0.522872, -0.40434, -0.593412, 0.918845, 1.05659, 0.266254, -0.438084, 0.626653, -0.232121, -0.346624, -0.533966, 0.179427, -0.509745, -0.187721, -0.253874, -0.286669, -0.462519, 0.569676, 0.0893508, -1.06513, 0.639664, 0.266718, -0.609418, 0.347195, 0.159792, 0.229622, -0.488133, 0.0085784, 0.485227, 0.0878157, 0.556075, -0.933147, 0.742958, -0.064349, 0.198085, -0.314378, 0.173344, -0.718609, -1.14167, ]) I 00:00:02.547035 executorch:mps_executor_runner.mm:535] Model verified successfully. ``` cc cccclai , shoumikhin Pull Request resolved: #2800 Reviewed By: shoumikhin Differential Revision: D55615735 Pulled By: cccclai fbshipit-source-id: 07c86ffbf941e688a7e5c5025aa4587afc260629
1 parent c06c89f commit c731673

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

backends/apple/mps/operators/constant_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ def define_node(
5555
elif node.target == exir_ops.edge.aten.empty.memory_format:
5656
fill_value = 0
5757
elif node.target == exir_ops.edge.aten.scalar_tensor.default:
58-
fill_value = float(node.args[0])
58+
fill_value = cast(float, node.args[0])
59+
60+
if fill_value == float("-inf"):
61+
fill_value = "-inf"
62+
elif fill_value == float("inf"):
63+
fill_value = "inf"
5964

6065
dtype = MPSDataType.mps_data_type_float32
6166
if node.kwargs and "dtype" in node.kwargs and node.kwargs["dtype"] is not None:

backends/apple/mps/operators/node_visitor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,9 @@ def define_scalar(
157157
"""
158158
assert isinstance(val, int) or isinstance(val, float)
159159

160-
if val in self.tensor_to_id:
161-
return self.tensor_to_id[val]
160+
id = len(mps_graph.mps_values)
161+
self.tensor_to_id[val] = id
162162

163-
id = self.get_serialized_id(val, mps_graph)
164163
tensor = torch.tensor(val)
165164
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
166165
tensor, mps_graph, mps_data_type, id

backends/apple/mps/serialization/mps_graph_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ class MPSFull:
391391

392392
@dataclass
393393
class MPSFullLike(MPSNode1x1):
394-
fill_value: float = 0.0
394+
fill_value: Union[float, str] = 0.0
395395
dtype: MPSDataType = MPSDataType.mps_data_type_float32
396396

397397

0 commit comments

Comments
 (0)