
Commit 484774f

fix eager run for cuda (#6429)
ghstack-source-id: 8278f05
Pull Request resolved: #6365
Co-authored-by: Lunwen He <[email protected]>
1 parent ca47839 commit 484774f

File tree

2 files changed: +21 −13 lines


examples/models/llama/runner/eager.py

Lines changed: 6 additions & 6 deletions

@@ -33,13 +33,13 @@ def __init__(self, args):
             use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
-        self.model = (
-            manager.model.eval().to(device="cuda")
-            if torch.cuda.is_available()
-            else manager.model.eval().to(device="cpu")
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
+            model_args=model_args,
+            device="cuda" if torch.cuda.is_available() else "cpu",
         )
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = manager.model.eval().to(device=self.device)
 
     def forward(
         self,
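
For readers skimming the diff, the eager.py change boils down to one pattern: resolve the device a single time in the constructor, hand it to the base runner, and move the eager model onto it. The sketch below illustrates that pattern in isolation; BaseRunner, ToyEagerRunner, and the nn.Linear stand-in are hypothetical placeholders, not code from this commit.

import torch
import torch.nn as nn


class BaseRunner:
    def __init__(self, device: str = "cpu"):
        # the base class remembers the device, mirroring LlamaRunner.__init__
        self.device = device


class ToyEagerRunner(BaseRunner):
    def __init__(self):
        # resolve the device once instead of branching at every call site
        super().__init__(device="cuda" if torch.cuda.is_available() else "cpu")
        # stand-in for manager.model; the real runner loads the Llama model here
        self.model = nn.Linear(8, 8).eval().to(device=self.device)


runner = ToyEagerRunner()
print(runner.device, next(runner.model.parameters()).device)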

examples/models/llama/runner/generation.py

Lines changed: 15 additions & 7 deletions

@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate(  # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,14 +88,21 @@ def generate(  # noqa: C901
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor(
+                        [[current_token]], dtype=torch.long, device=self.device
+                    ),
+                    input_pos=torch.tensor(
+                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
-                hasattr(self, "stop_tokens") and current_token in self.stop_tokens
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
             ):
                 break
             tokens.append(current_token)
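
The generation.py change exists because an eager CUDA model cannot consume CPU tensors: every token and position tensor built inside generate() now carries device=self.device. Below is a rough, self-contained illustration of that failure mode and the fix; the toy embedding model is a placeholder, not part of the runner.

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Embedding(32, 4).to(device)

# post-fix behavior: inputs are created directly on the model's device
tokens = torch.tensor([[1, 2, 3]], dtype=torch.long, device=device)
logits = model(tokens)  # fine: input and weights share a device

# pre-fix behavior: torch.tensor defaults to CPU
cpu_tokens = torch.tensor([[1, 2, 3]], dtype=torch.long)
# On a CUDA machine, model(cpu_tokens) raises a RuntimeError about tensors
# being on different devices, which is the eager-run failure this commit fixes.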
