Skip to content

Commit 24b37f2

Browse files
Yanghan Wangfacebook-github-bot
authored andcommitted
make builder accessible (#3743)
Summary: Pull Request resolved: #3743 Make the `builder` accessible after calling `_export_llama`. Also add a member method to return most recent saved model path (which guarantee ending with `.pte`, so that user doesn't need to think about complicated logic to check if it's model name or file name). Reviewed By: cccclai Differential Revision: D57801568 fbshipit-source-id: 4317d85a3aa8e54e0919e385e20674ddacfbf512
1 parent 2b91eba commit 24b37f2

File tree

4 files changed

+25
-6
lines changed

4 files changed

+25
-6
lines changed

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ runtime.python_library(
7979
"//bento/...",
8080
"//bento_kernels/...",
8181
"//executorch/examples/...",
82+
"@EXECUTORCH_CLIENTS",
8283
],
8384
deps = [
8485
"//caffe2:torch",

examples/models/llama2/builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
self.edge_manager: Optional[EdgeProgramManager] = None
157157
self.export_program = None
158158
self.output_dir = "."
159+
self._saved_pte_filename = None
159160

160161
def set_metadata(self, metadata: Optional[dict]) -> "LlamaEdgeManager":
161162
"""
@@ -388,4 +389,11 @@ def save_to_pte(self, output_name: str) -> None:
388389
output_name (Optional[str]): The name of the .pte file.
389390
"""
390391
assert output_name, "Need a valid output name"
391-
save_pte_program(self.export_program, output_name, self.output_dir)
392+
filename = save_pte_program(self.export_program, output_name, self.output_dir)
393+
self._saved_pte_filename = filename
394+
395+
def get_saved_pte_filename(self) -> Optional[str]:
396+
"""
397+
Return the filename of the most recenet saved .pte file. Return None if the model is not saved.
398+
"""
399+
return self._saved_pte_filename

examples/models/llama2/export_llama_lib.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,22 @@ def export_llama(modelname, args) -> str:
293293
from executorch.util.python_profiler import CProfilerFlameGraph
294294

295295
with CProfilerFlameGraph(args.profile_path):
296-
return _export_llama(modelname, args)
296+
builder = _export_llama(modelname, args)
297+
assert (
298+
filename := builder.get_saved_pte_filename()
299+
) is not None, "Fail to get file name from builder"
300+
return filename
297301
except ImportError:
298302
print(
299303
"Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph."
300304
)
301305
return ""
302306
else:
303-
return _export_llama(modelname, args)
307+
builder = _export_llama(modelname, args)
308+
assert (
309+
filename := builder.get_saved_pte_filename()
310+
) is not None, "Fail to get file name from builder"
311+
return filename
304312

305313

306314
def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
@@ -383,7 +391,7 @@ def get_quantizer_and_quant_params(args):
383391
return pt2e_quant_params, quantizers, quant_dtype
384392

385393

386-
def _export_llama(modelname, args) -> str: # noqa: C901
394+
def _export_llama(modelname, args) -> LlamaEdgeManager: # noqa: C901
387395
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
388396

389397
# export_to_edge
@@ -468,4 +476,4 @@ def _export_llama(modelname, args) -> str: # noqa: C901
468476

469477
builder.save_to_pte(output_file)
470478

471-
return output_file
479+
return builder

examples/portable/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def export_to_exec_prog(
102102

103103
def save_pte_program(
104104
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
105-
) -> None:
105+
) -> str:
106106
if model_name.endswith(".pte"):
107107
filename = model_name
108108
else:
@@ -114,3 +114,5 @@ def save_pte_program(
114114
logging.info(f"Saved exported program to {filename}")
115115
except Exception as e:
116116
logging.error(f"Error while saving to {filename}: {e}")
117+
118+
return filename

0 commit comments

Comments
 (0)