Skip to content

Commit dbb13fa

Browse files
DenisVieriu97dbort
authored andcommitted
Fix variable number of inputs in mps_example.py (#918)
Summary: - fix models with variable number of inputs (>1). Needed for the models `add`, `linear`, `add_mul` where the inputs array is usually made out of 2/3 tensors. - llama2 is also working correctly through MPS delegate. Set the `atol` to 1e-4 when doing the bundled program comparison check with the eager output from CPU. Pull Request resolved: #918 Reviewed By: kirklandsign Differential Revision: D50301445 Pulled By: kimishpatel fbshipit-source-id: 38576363de60fc45dabc92833723cf8f293cc7c5
1 parent 3e0130b commit dbb13fa

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

examples/apple/mps/executor_runner/mps_executor_runner.mm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,11 @@ MemoryManager memory_manager(
481481
strstr(model_path, "emformer_transcribe") ||
482482
strstr(model_path, "emformer_join") ||
483483
strstr(model_path, "edsr") ||
484+
strstr(model_path, "llama2") ||
484485
strstr(model_path, "ic3") ||
485486
strstr(model_path, "ic4")) {
486487
atol = 1e-04;
487-
}
488+
}
488489
status = torch::executor::util::VerifyResultWithBundledExpectedOutput(
489490
*method,
490491
file_data->data(),

examples/apple/mps/scripts/mps_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def __init__(self):
7575
super().__init__()
7676
self.mps_module = lowered_module
7777

78-
def forward(self, input_args):
79-
return self.mps_module(input_args)
78+
def forward(self, *input_args):
79+
return self.mps_module(*input_args)
8080

8181
executorch_program = (
8282
exir.capture(

0 commit comments

Comments
 (0)