Commit 88f4d23

nvzhihanj and mrmhodak authored
Add mixtral dockerfile and standalone inference script (#2029)
* Add dockerfile and standalone accuracy evaluation scripts
* Minor fixes

Co-authored-by: Miro <[email protected]>
1 parent c40d6e1 commit 88f4d23

6 files changed: +552 additions, −8 deletions


language/llama2-70b/evaluate-accuracy.py

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ def main():
     checkpoint_path = args.checkpoint_path
     metric = evaluate.load("rouge")
     nltk.download("punkt")
+    nltk.download("punkt_tab")

     tokenizer = AutoTokenizer.from_pretrained(
         checkpoint_path,
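For context on this one-line fix, here is a minimal sketch (not part of the commit) of the sentence splitting the ROUGE post-processing relies on: recent NLTK releases moved the Punkt tokenizer data into a separate `punkt_tab` package, so both downloads are needed.

```python
# Minimal sketch (not from this commit): sent_tokenize on newer NLTK
# raises LookupError unless the "punkt_tab" resource is also present.
import nltk

nltk.download("punkt")
nltk.download("punkt_tab")  # harmless on older NLTK releases

print(nltk.sent_tokenize("First sentence. Second sentence."))
# ['First sentence.', 'Second sentence.']
```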

language/mixtral-8x7b/Dockerfile

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
+FROM nvcr.io/nvidia/pytorch:24.07-py3
 SHELL ["/bin/bash", "-c"]

 ENV LC_ALL=C.UTF-8
@@ -22,7 +22,7 @@ ENV TZ=US/Pacific
 ENV DEBIAN_FRONTEND=noninteractive

 RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
-RUN rm -rf /var/lib/apt/lists/* && rm /etc/apt/sources.list.d/* \
+RUN rm -rf /var/lib/apt/lists/* && rm -rf /etc/apt/sources.list.d/* \
     && apt update \
     && apt install -y --no-install-recommends build-essential autoconf \
        libtool git ccache curl wget pkg-config sudo ca-certificates \
@@ -44,5 +44,5 @@ WORKDIR /tmp
 RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-Linux-x86_64.sh \
     && bash Miniconda3-* -b -p /opt/miniconda3
 ENV PATH="$PATH:/opt/miniconda3/bin"
-RUN conda create -n llama2-70b python=3.10
+RUN conda create -n llm python=3.10
 RUN chmod -R 777 /opt/miniconda3
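The base image moves from a bare CUDA/cuDNN image to the NGC PyTorch 24.07 image, which ships PyTorch preinstalled. A quick sanity check inside the built container (a hypothetical snippet, not part of the commit) confirms the GPU stack before running inference:

```python
# Hypothetical sanity check to run inside the built container:
# the NGC PyTorch image bundles a CUDA-enabled torch build.
import torch

print(torch.__version__)          # torch build shipped with the image
print(torch.cuda.is_available())  # True when run with --runtime=nvidia
print(torch.cuda.device_count())  # number of visible GPUs
```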

language/mixtral-8x7b/README.md

Lines changed: 4 additions & 4 deletions
@@ -10,7 +10,7 @@


 Please see the [new docs site](https://docs.mlcommons.org/inference/benchmarks/language/mixtral-8x7b) for an automated way to run this benchmark across different available implementations and do an end-to-end submission with or without docker.
-
+
 ## Prepare environment

 For a CPU-only run:
@@ -234,11 +234,11 @@ Recreating the environment for evaluating the quality metrics can be quite tedious
 ```bash
 docker build . -f Dockerfile.eval -t evaluation
 ```
-2. Run the docker in interactive mode and with
+2. Run the docker in interactive mode and with
 ```bash
-sudo docker run -it -v $(pwd):/eval -t evaluation
+docker run -it --rm --net=host --runtime=nvidia --ipc=host -v $PWD:$PWD -w $PWD evaluation
 ```
-3.
+3.
 ```bash
 cd eval
 python -u evaluate-accuracy.py --checkpoint-path [path_to_model_checkpoint] \
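For reference, the quality check this container runs boils down to a ROUGE computation with the `evaluate` library (see the evaluate-accuracy.py diff above); a minimal sketch with illustrative strings, not actual dataset samples:

```python
# Minimal sketch of the ROUGE scoring performed by evaluate-accuracy.py;
# the real script also loads the model tokenizer to detokenize outputs.
import evaluate
import nltk

nltk.download("punkt")
nltk.download("punkt_tab")

metric = evaluate.load("rouge")
scores = metric.compute(
    predictions=["The first center processes 10000 packages per day."],
    references=["The first center processes 10000 packages per day."],
)
print(scores)  # identical strings give rouge1 = rouge2 = rougeL = 1.0
```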
language/mixtral-8x7b/standalone_infer/README.md (new file)

Lines changed: 34 additions & 0 deletions

# Mixtral reference standalone inference script

The reference output and accuracy can be checked with the standalone Hugging Face inference script by following the instructions below:

```
cd language/mixtral-8x7b
docker build -t mlc-ngc .
nvidia-docker run -it --rm --net=host --runtime=nvidia --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --cap-add=SYS_PTRACE --cap-add=SYS_ADMIN --cap-add=DAC_READ_SEARCH --security-opt seccomp=unconfined -w $PWD -v $PWD:$PWD -t mlc-ngc

pip install -r requirements.txt
cd standalone_infer
# Make sure the checkpoint and reference pickle file are already downloaded
python3 hf_eval_all.py --input_pkl=09292024_mixtral_15k_mintoken2_v1.pkl --checkpoint_path=/raid/data/mlperf-llm/Mixtral-8x7B-Instruct-v0.1 --output_pkl=mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl --batch_size=64

# Exit the container and enter the evaluation container
exit
docker build . -f Dockerfile.eval -t evaluation
docker run -it --rm --net=host --runtime=nvidia --ipc=host -v $PWD:$PWD -w $PWD evaluation
cd standalone_infer
python3 run_accuracy.py --results_path=mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl
```

Expected output:
```
EM: 0.7366, correct: 3683 / 5000, gen_token_per_sample: 129.9604
Evaluating OpenOrca score...
OpenOrca score: {'rouge1': np.float64(45.5989), 'rouge2': np.float64(23.3526), 'rougeL': np.float64(30.4608), 'rougeLsum': np.float64(42.5396)}, gen_token_per_sample: 205.8656
Evaluating MBXP score...
100%|██████████| 5000/5000 [02:33<00:00, 32.50it/s]
Processed 5000 in 153.89411109898356s
60.16% pass@1
{'cpp': 381, 'typescript': 438, 'ruby': 419, 'python': 492, 'php': 809, 'javascript': 469} out of {'cpp': 743, 'typescript': 868, 'ruby': 846, 'python': 863, 'php': 846, 'javascript': 834}
gen_tokens_per_sample: 98.7026
```
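The EM line above is exact match over the GSM8K split; run_accuracy.py (invoked in the instructions above) is the authoritative implementation, but a hypothetical sketch of the idea looks like this:

```python
# Hypothetical sketch of a GSM8K exact-match (EM) check; not the actual
# run_accuracy.py logic, just an illustration of the metric.
import re

def last_number(text: str):
    """Pull the final numeric token out of a generation."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

def exact_match(generation: str, gt_output: str) -> bool:
    return last_number(generation) == last_number(gt_output)

print(exact_match("... so both centers handle 14,000 packages.", "14000"))  # True
```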
language/mixtral-8x7b/standalone_infer/hf_eval_all.py (new file)

Lines changed: 194 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
import torch
import pandas as pd
import time
from pathlib import Path
import argparse


def run_infer(df, ckpt_path, bs):
    """
    Each input row looks like:
    dataset                                                         GSM8K
    id                                                          train.548
    question            Gary manages two Amazon distribution centers. ...
    input               <s> [INST] As an expert problem solver solve s...
    ref_output          The first center processes 10000 packages per ...
    gt_output                                                       14000
    tok_input           [1, 1, 28705, 733, 16289, 28793, 1136, 396, 75...
    tok_ref_output      [415, 907, 4982, 9537, 28705, 28740, 28734, 28...
    stop_sequence                                                    </s>
    tok_stop_sequence                                                 [2]
    tok_input_len                                                     662
    tok_ref_output_len                                                174
    Name: 0, dtype: object
    """
    device = "cuda"  # the device to load the model onto

    # Load the model from local if possible.
    model_path = Path(ckpt_path)
    if not model_path.exists():
        raise RuntimeError(
            f"{ckpt_path} does not exist. Please download the checkpoint from MLCommons")

    tokenizer = AutoTokenizer.from_pretrained(
        model_path, padding_side="left", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Generation parameters. We stop at 1024 new tokens. Starting from v5.0,
    # min_new_tokens is set to 2 to avoid the 0-output issue.
    gen_kwargs = {
        "min_new_tokens": 2,
        "max_new_tokens": 1024,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
    }

    # Start inference
    BS = bs
    bidx = 0
    model.eval()

    input_tokens = []
    input_tokens_lens = []
    output_tokens = []
    output_tokens_lens = []
    output_texts = []

    tic = time.time()
    for idx in range(0, len(df), BS):
        tac = time.time()
        print(f"Processing {idx}/{len(df)}, time: {tac - tic}s")
        sidx = idx
        eidx = min(sidx + BS, len(df))

        # We use batch_encode_plus for batch inference.
        # Note 9/29/2024: Mixtral changed its tokenizer in June. Using the Feb 29 2024 version.
        batch_texts = df['input'][sidx:eidx].tolist()
        batch_ids = tokenizer.batch_encode_plus(
            batch_texts, return_tensors="pt", padding=True)
        tok_input_id = batch_ids['input_ids'].to(torch.int32).tolist()
        # Remove EOS from the input ids
        tok_input_id = [[element for element in sublist if element !=
                         tokenizer.eos_token_id] for sublist in tok_input_id]
        input_tokens += tok_input_id
        tok_input_length = [len(seq) for seq in tok_input_id]
        input_tokens_lens += tok_input_length

        batch_ids = batch_ids.to(device)
        _, length = batch_ids.input_ids.shape
        outputs = model.generate(**batch_ids, num_return_sequences=1,
                                 **gen_kwargs)

        output_ids = outputs[:, length:].cpu().tolist()
        output_tokens += output_ids

        # Filter out EOS
        id_filtered = [[num for num in sublist if num !=
                        tokenizer.eos_token_id] for sublist in output_ids]
        output_id_len = [len(out) for out in id_filtered]
        output_tokens_lens += output_id_len

        # Detokenize
        output_msgs = tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)
        output_texts += output_msgs
        bidx += 1

    # Assemble the output
    output_df = df[:len(output_tokens)].copy()
    output_df["infer_tok_input"] = input_tokens
    output_df["infer_tok_input_length"] = input_tokens_lens
    output_df["infer_ref_output"] = output_texts
    output_df["infer_tok_ref_output"] = output_tokens
    output_df["infer_tok_ref_output_length"] = output_tokens_lens

    return output_df


def trim_twos(df):
    # Remove all trailing 2s (the Mixtral EOS token id) except for one
    def remove_trailing_twos(lst):
        count = 0
        for num in reversed(lst):
            if num == 2:
                count += 1
            else:
                break
        return lst[:-count] if count > 0 else lst

    df['infer_tok_ref_output'] = df['infer_tok_ref_output'].apply(
        remove_trailing_twos)
    df['trim_lengths'] = df['infer_tok_ref_output'].apply(len)
    df['tok_ref_output'] = df['tok_ref_output'].apply(remove_trailing_twos)
    df['tok_ref_output_len'] = df['tok_ref_output'].apply(len)
    return df


def mbxp_stop(df):
    # Truncate MBXP generations at the code-block stop sequence
    stop_tokens = [13, 13940, 28832, 13]

    def modify_list(lst):
        for i in range(len(lst) - len(stop_tokens) + 1):
            if lst[i:i + len(stop_tokens)] == stop_tokens:
                return lst[:i + len(stop_tokens)]
        return lst

    df.loc[df['dataset'] == 'MBXP', 'infer_tok_ref_output'] = \
        df[df['dataset'] == 'MBXP']['infer_tok_ref_output'].apply(modify_list)
    df['trim_lengths'] = df['infer_tok_ref_output'].apply(len)
    return df


def fix_name(df):
    # Replace the old reference columns with the freshly inferred ones
    df.drop(columns=['ref_output'], inplace=True)
    df.drop(columns=['tok_ref_output'], inplace=True)
    df.drop(columns=['tok_ref_output_len'], inplace=True)
    df.drop(columns=['infer_tok_ref_output_length'], inplace=True)
    df.drop(columns=['infer_tok_input'], inplace=True)
    df.drop(columns=['infer_tok_input_length'], inplace=True)
    df.rename(columns={'infer_ref_output': 'ref_output'}, inplace=True)
    df.rename(columns={'infer_tok_ref_output': 'tok_ref_output'}, inplace=True)
    df.rename(columns={'trim_lengths': 'tok_ref_output_len'}, inplace=True)

    return df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_pkl", type=str, default="09292024_mixtral_15k_mintoken2_v1.pkl",
                        help="The path to the input pkl file")
    parser.add_argument("--output_pkl", type=str, default="mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl",
                        help="The path to the output pickle.")
    parser.add_argument("--checkpoint_path", type=str, default="/raid/data/mlperf-llm/Mixtral-8x7B-Instruct-v0.1",
                        help="The path to the mixtral checkpoint")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size of the reference inference")
    args = parser.parse_args()

    df = pd.read_pickle(args.input_pkl)
    df = run_infer(df, args.checkpoint_path, args.batch_size)

    df = trim_twos(df)
    df = mbxp_stop(df)
    df = fix_name(df)

    df.to_pickle(args.output_pkl)
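After a run, the output pickle can be spot-checked before scoring; a small sketch (not part of the commit) using the script's default --output_pkl name:

```python
# Sketch: inspect the reference pickle produced by hf_eval_all.py
# (file name is the script's default --output_pkl value).
import pandas as pd

df = pd.read_pickle("mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl")
print(df.columns.tolist())           # ref_output / tok_ref_output / tok_ref_output_len, etc.
print(df["dataset"].value_counts())  # per-dataset sample counts (e.g. GSM8K, MBXP)
print(df["tok_ref_output_len"].describe())
```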
