Skip to content

Commit 36c474a

Browse files
committed
Fix lint
1 parent 880416e commit 36c474a

File tree

5 files changed

+81
-46
lines changed

5 files changed

+81
-46
lines changed

backends/apple/mps/test/test_mps.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1190,7 +1190,11 @@ def forward(self, x, rows, columns):
11901190
x = torch.arange(0, 12).resize(4, 3)
11911191
rows = torch.tensor([[0, 0], [3, 3]])
11921192
columns = torch.tensor([[0, 2], [0, 2]])
1193-
model_inputs = (x, rows, columns, )
1193+
model_inputs = (
1194+
x,
1195+
rows,
1196+
columns,
1197+
)
11941198

11951199
self.lower_and_test_with_partitioner(
11961200
module, model_inputs, func_name=inspect.stack()[0].function[5:]

backends/apple/mps/test/test_mps_indexing_ops.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def forward(self, x):
6464
return x[:, [0, 1, 0], [0, 1, 0]]
6565

6666
module = IndexGet()
67-
model_inputs = (torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),)
67+
model_inputs = (
68+
torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]),
69+
)
6870

6971
self.lower_and_test_with_partitioner(
7072
module, model_inputs, func_name=inspect.stack()[0].function[5:]
@@ -212,9 +214,12 @@ def forward(self, x, y, z):
212214
input = torch.ones(1, 8, 128, 8)
213215
indices = torch.tensor([1])
214216
values = torch.randn(8, 1, 8)
215-
model_inputs = (input, indices, values, )
217+
model_inputs = (
218+
input,
219+
indices,
220+
values,
221+
)
216222

217223
self.lower_and_test_with_partitioner(
218224
module, model_inputs, func_name=inspect.stack()[0].function[5:]
219225
)
220-

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,11 @@
2222
from executorch.exir.backend.backend_details import CompileSpec
2323
from executorch.exir.capture._config import ExecutorchBackendConfig
2424
from executorch.exir.tracer import Value
25-
from executorch.extension.pytree import tree_flatten
2625
from executorch.sdk import BundledProgram
2726
from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite
2827
from executorch.sdk.bundled_program.serialize import (
2928
serialize_from_bundled_program_to_flatbuffer,
3029
)
31-
from torch._export import capture_pre_autograd_graph
3230
from torch.export import export, ExportedProgram
3331

3432
# Config for Capturing the weights, will be moved in the future
@@ -201,7 +199,7 @@ def lower_module_and_test_output(
201199
func_name: str,
202200
use_partitioner: bool = True,
203201
use_fp16: bool = False,
204-
bundled_program = True,
202+
bundled_program=True,
205203
) -> ExirExportedProgram:
206204
"""
207205
Helper testing function that takes a torch.nn.Module and lowers it to MPS with

examples/apple/mps/scripts/bench_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
# Provided subject to the LICENSE file in the top level directory.
44
#
55

6-
from torch._export.exported_program import ExportedProgram
7-
import torch
8-
import time
96
import logging
7+
import time
8+
9+
import torch
10+
from torch._export.exported_program import ExportedProgram
11+
1012

1113
def assert_outputs_equal(model_output, ref_output):
1214
"""
@@ -19,16 +21,19 @@ def assert_outputs_equal(model_output, ref_output):
1921
# Compare the result from executor and eager mode direclty
2022
if isinstance(ref_output, tuple) or isinstance(ref_output, list):
2123
# Multiple outputs executor always returns tuple, even if there is one output
22-
assert len(ref_output) == len(model_output), "Length of outputs is not matching!"
24+
assert len(ref_output) == len(
25+
model_output
26+
), "Length of outputs is not matching!"
2327
for i in range(len(ref_output)):
24-
assert(
25-
torch.allclose(
26-
model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
27-
)
28+
assert torch.allclose(
29+
model_output[i], ref_output[i], atol=1e-03, rtol=1e-03
2830
)
2931
else:
3032
# If one output, eager returns tensor while executor tuple of size 1
31-
assert torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03), "Outputs are not matching!"
33+
assert torch.allclose(
34+
model_output[0], ref_output, atol=1e-03, rtol=1e-03
35+
), "Outputs are not matching!"
36+
3237

3338
def bench_forward(func, *args):
3439
# warmup
@@ -41,39 +46,44 @@ def bench_forward(func, *args):
4146
end = time.time()
4247
return end - start
4348

49+
4450
def executorch_forward_pass(model, inputs):
4551
for _ in range(10):
4652
model.forward(inputs)
4753

54+
4855
def synchronize():
4956
torch.mps.synchronize()
5057

58+
5159
def pytorch_forward_pass(model, inputs):
5260
for _ in range(10):
5361
model(*inputs)
5462
synchronize()
5563

64+
5665
def get_mps_inputs(inputs):
5766
inputs_mps = []
5867
for tensor in inputs:
5968
inputs_mps.append(tensor.to("mps"))
6069
inputs_mps = tuple(inputs_mps)
6170
return inputs_mps
6271

72+
6373
def get_executorch_model(executorch_program: ExportedProgram):
6474
try:
6575
from executorch.extension.pybindings.portable_lib import ( # @manual
6676
_load_for_executorch_from_buffer,
6777
)
68-
return _load_for_executorch_from_buffer(
69-
executorch_program.buffer
70-
)
78+
79+
return _load_for_executorch_from_buffer(executorch_program.buffer)
7180
except ImportError:
7281
logging.info(
7382
"ExecuTorch MPS delegate was built without pybind support (not possible to run forward pass within python)"
7483
)
7584
return None
7685

86+
7787
def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
7888
model = model.to("mps")
7989
inputs_mps = get_mps_inputs(inputs)
@@ -86,7 +96,10 @@ def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name):
8696
logging.info(f"Model name: {model_name}")
8797
logging.info(f"Pytorch MPS forward pass: {t_pytorch} seconds")
8898
logging.info(f"ExecuTorch MPS forward pass: {t_executorch} seconds")
89-
logging.info(f"ExecuTorch speedup: {((t_pytorch - t_executorch) / t_pytorch) * 100}%")
99+
logging.info(
100+
f"ExecuTorch speedup: {((t_pytorch - t_executorch) / t_pytorch) * 100}%"
101+
)
102+
90103

91104
def compare_outputs(executorch_program: ExportedProgram, model, inputs, model_name):
92105
inputs_copy = []
@@ -99,4 +112,6 @@ def compare_outputs(executorch_program: ExportedProgram, model, inputs, model_na
99112
if executorch_model is not None:
100113
executorch_results = executorch_model.forward(inputs_copy)
101114
assert_outputs_equal(executorch_results, pytorch_results)
102-
logging.info(F"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass are matching!")
115+
logging.info(
116+
f"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for {model_name} are matching!"
117+
)

examples/apple/mps/scripts/mps_example.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111

1212
import torch
13+
from examples.apple.mps.scripts.bench_utils import bench_torch, compare_outputs
1314
from executorch import exir
1415
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
1516
from executorch.backends.apple.mps.partition.mps_partitioner import MPSPartitioner
@@ -27,10 +28,6 @@
2728
from executorch.sdk.bundled_program.serialize import (
2829
serialize_from_bundled_program_to_flatbuffer,
2930
)
30-
from examples.apple.mps.scripts.bench_utils import (
31-
bench_torch,
32-
compare_outputs,
33-
)
3431

3532
from ....models import MODEL_NAME_TO_MODEL
3633
from ....models.model_factory import EagerModelFactory
@@ -40,7 +37,28 @@
4037
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
4138
logging.basicConfig(level=logging.INFO, format=FORMAT)
4239

43-
if __name__ == "__main__":
40+
41+
def get_bundled_program(executorch_program, example_inputs, expected_output):
42+
method_test_suites = [
43+
MethodTestSuite(
44+
method_name="forward",
45+
test_cases=[
46+
MethodTestCase(
47+
inputs=example_inputs, expected_outputs=[expected_output]
48+
)
49+
],
50+
)
51+
]
52+
logging.info(f"Expected output: {expected_output}")
53+
54+
bundled_program = BundledProgram(executorch_program, method_test_suites)
55+
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
56+
bundled_program
57+
)
58+
return bundled_program_buffer
59+
60+
61+
def parse_args():
4462
parser = argparse.ArgumentParser()
4563
parser.add_argument(
4664
"-m",
@@ -111,10 +129,10 @@
111129
)
112130

113131
args = parser.parse_args()
132+
return args
114133

115-
if args.model_name not in MODEL_NAME_TO_MODEL:
116-
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
117134

135+
def get_model_config(args):
118136
model_config = {}
119137
model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
120138
model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]
@@ -125,10 +143,17 @@
125143
if args.params:
126144
model_config["params"] = args.params
127145
model_config["use_kv_cache"] = True
146+
return model_config
128147

129-
model, example_inputs, _ = EagerModelFactory.create_model(
130-
**model_config
131-
)
148+
149+
if __name__ == "__main__":
150+
args = parse_args()
151+
152+
if args.model_name not in MODEL_NAME_TO_MODEL:
153+
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
154+
155+
model_config = get_model_config(args)
156+
model, example_inputs, _ = EagerModelFactory.create_model(**model_config)
132157

133158
model = model.eval()
134159
if args.check_correctness or args.bench_pytorch:
@@ -172,21 +197,9 @@
172197
model_name = f"{args.model_name}_mps"
173198

174199
if args.bundled:
175-
method_test_suites = [
176-
MethodTestSuite(
177-
method_name="forward",
178-
test_cases=[
179-
MethodTestCase(
180-
inputs=example_inputs, expected_outputs=[model(*example_inputs)]
181-
)
182-
],
183-
)
184-
]
185-
logging.info(f"Expected output: {model(*example_inputs)}")
186-
187-
bundled_program = BundledProgram(executorch_program, method_test_suites)
188-
bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
189-
bundled_program
200+
expected_output = model(*example_inputs)
201+
bundled_program_buffer = get_bundled_program(
202+
executorch_program, example_inputs, expected_output
190203
)
191204
model_name = f"{model_name}_bundled"
192205
extension = "fp16"

0 commit comments

Comments
 (0)