Commit 5429eea

add the ability to run eager runner via buck

Differential Revision: D64730344
Pull Request resolved: #6506
1 parent 85d3ff6

6 files changed: +41 -15 lines

examples/models/llama/TARGETS (0 additions, 4 deletions)

@@ -126,10 +126,6 @@ runtime.python_library(
 runtime.python_binary(
     name = "eval_llama",
     main_function = "executorch.examples.models.llama.eval_llama.main",
-    preload_deps = [
-        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
-        "//executorch/kernels/quantized:aot_lib",
-    ],
     deps = [
         ":eval_library",
         "//caffe2:torch",

examples/models/llama/eval_llama_lib.py (2 additions, 1 deletion)

@@ -293,6 +293,7 @@ def eval_llama(
 
     # Needed for loading mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
+    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets
 

@@ -302,7 +303,7 @@ def eval_llama(
     with torch.no_grad():
         eval_results = simple_evaluate(
             model=eval_wrapper,
-            tasks=args.tasks,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
+            tasks=args.tasks,
             num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
             limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
         )
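Note: the error text shows why these suppressions exist at all: `args` is evidently annotated as `argparse.ArgumentParser` rather than `argparse.Namespace`, so Pyre flags every attribute access on it. This change just moves the `tasks` suppression from the `simple_evaluate` call site up to the first flagged use in the `if` check.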

examples/models/llama/runner/TARGETS (29 additions, 0 deletions)

@@ -1,8 +1,37 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.
 
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 define_common_targets()
+
+runtime.python_library(
+    name = "eager_runner_library",
+    srcs = [
+        "eager.py",
+        "generation.py"
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama.runner",
+    visibility = [
+        "//bento/...",
+        "//bento_kernels/...",
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/examples/models/llama:export_library",
+    ],
+)
+
+runtime.python_binary(
+    name = "eager",
+    main_function = "executorch.examples.models.llama.runner.eager.main",
+    deps = [
+        ":eager_runner_library",
+        "//caffe2:torch",
+    ],
+)
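With the `eager` python_binary defined, the eager runner can be launched through buck instead of invoking the script directly. A hypothetical invocation follows; the target path comes from this TARGETS file, but the flag names are assumptions, since the argument parser is the one shared with `export_llama`:

    buck2 run //executorch/examples/models/llama/runner:eager -- \
        --checkpoint /path/to/consolidated.00.pth \
        --params /path/to/params.json \
        --prompt "Hello, world"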

examples/models/llama/runner/eager.py (4 additions, 5 deletions)

@@ -9,14 +9,13 @@
 from typing import Optional
 
 import torch
-
-from examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
 )
+from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.runner.generation import LlamaRunner
-from executorch.extension.llm.export import LLMEdgeManager
+from executorch.extension.llm.export.builder import LLMEdgeManager
 
 
 class EagerLlamaRunner(LlamaRunner):

@@ -43,8 +42,8 @@ def __init__(self, args):
 
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model.forward(tokens=tokens, input_pos=input_pos)
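Note: two things change here. The imports become fully qualified (`executorch.examples...` instead of the bare `examples...` path, and `LLMEdgeManager` is imported from its defining module `executorch.extension.llm.export.builder`), which is presumably what makes module resolution work when the script runs as a buck python_binary with `base_module` set. And `forward` now takes `tokens` as a required plain `torch.Tensor`, matching the tightened abstract signature in generation.py below.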

examples/models/llama/runner/generation.py (4 additions, 3 deletions)

@@ -15,7 +15,7 @@
 
 class CompletionPrediction(TypedDict, total=False):
     generation: str
-    tokens: List[str]  # not required
+    tokens: List[int]  # not required
 
 
 def sample_top_p(probs, p):

@@ -47,6 +47,7 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
     if temperature > 0:
         probs = torch.softmax(logits / temperature, dim=-1)
         return sample_top_p(probs, top_p).item()
+    # Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]`
     return torch.argmax(logits, dim=-1).item()
 
 

@@ -60,8 +61,8 @@ def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
     @abstractmethod
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         pass
 
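The body of `sample_top_p` falls outside this diff. For context, a minimal sketch of the standard top-p (nucleus) sampling routine used by the Llama reference code, assuming `probs` is a softmax distribution over the vocabulary:

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        # Sort descending and accumulate probability mass.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Drop tokens whose cumulative mass, excluding themselves, exceeds p.
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        # Renormalize the surviving nucleus and sample one token.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_tok = torch.multinomial(probs_sort, num_samples=1)
        # Map the sampled position back to a vocabulary index.
        return torch.gather(probs_idx, -1, next_tok)

`next_token` then calls `.item()` on the result, which is also why the new Pyre-ignore[7] is needed on the argmax branch: `.item()` is typed as returning `Union[bool, float, int]`, not `int`.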

examples/models/llama/runner/native.py (2 additions, 2 deletions)

@@ -42,8 +42,8 @@ def __init__(self, args):
 
     def forward(
         self,
-        tokens: Optional[torch.LongTensor] = None,
-        input_pos: Optional[torch.LongTensor] = None,
+        tokens: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return (
             self.model.forward((tokens, input_pos))
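Note: with this, all three forward signatures agree: the abstract `LlamaRunner.forward`, `EagerLlamaRunner.forward`, and this native (pybindings) runner all take a required `tokens` tensor, with only `input_pos` staying optional. The old `Optional[torch.LongTensor] = None` defaults implied a tokens-less call that the underlying model forwards presumably could not accept anyway.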
