Security Update and Enhancement for run.py #264
base: main
Changes from 1 commit
@@ -1,7 +1,7 @@
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
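The next hunk touches def validate_checkpoint(path, expected_hash), and main() calls it with CKPT_PATH and CKPT_HASH, but the function body is not visible in this view. As a rough sketch only (the hashing scheme is an assumption, not taken from the PR, and a directory-style checkpoint would need to walk its contents instead of hashing a single file), such an integrity check could look like:

```python
import hashlib
import os


def validate_checkpoint(path, expected_hash):
    """Compare the SHA-256 digest of a checkpoint file against an expected value.

    Assumption: `expected_hash` is a lowercase hex SHA-256 digest; the PR may
    use a different scheme or checkpoint layout.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in chunks so large checkpoint files are not loaded into memory at once.
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    if digest.hexdigest() != expected_hash:
        raise ValueError(f"Checkpoint hash mismatch for {path}")
```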
@@ -30,54 +30,59 @@ def validate_checkpoint(path, expected_hash):
def main():
# Validate checkpoint integrity
validate_checkpoint(CKPT_PATH, CKPT_HASH)
grok_1_model = LanguageModelConfig(

Review comment: The only change here is a dedent from the PEP8 standard 4 space indent.

vocab_size=128 * 1024,
pad_token=0,
eos_token=2,
sequence_len=8192,
embedding_init_scale=1.0,
output_multiplier_scale=0.5773502691896257,
embedding_multiplier_scale=78.38367176906169,
model=TransformerConfig(
emb_size=48 * 128,
widening_factor=8,
key_size=128,
num_q_heads=48,
num_kv_heads=8,
num_layers=64,
attn_output_multiplier=0.08838834764831845,
shard_activations=True,
# MoE.
num_experts=8,
num_selected_experts=2,
# Activation sharding.
data_axis="data",
model_axis="model",
),
)
inference_runner = InferenceRunner(
pad_sizes=(1024,),
runner=ModelRunner(
model=grok_1_model,
bs_per_device=0.125,
checkpoint_path=CKPT_PATH,
# Limit inference rate
inference_runner.rate_limit = 100

Review comment: This appears to reference `inference_runner` before it is defined.

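An assignment statement also cannot appear inside the ModelRunner(...) argument list, so as written this line is a syntax error, and whether InferenceRunner even exposes a rate_limit attribute is not established by this diff. A hedged sketch of one way to throttle requests after the runner is constructed, without relying on such an attribute (names are illustrative, not from the PR):

```python
import threading
import time


class SimpleRateLimiter:
    """Allow at most `max_calls` calls per `period` seconds (illustrative only)."""

    def __init__(self, max_calls, period=60.0):
        self.max_calls = max_calls
        self.period = period
        self._calls = []
        self._lock = threading.Lock()

    def acquire(self):
        with self._lock:
            now = time.monotonic()
            # Drop timestamps that fall outside the sliding window.
            self._calls = [t for t in self._calls if now - t < self.period]
            if len(self._calls) >= self.max_calls:
                return False
            self._calls.append(now)
            return True


# Usage sketch: gate calls into the already-constructed generator.
# limiter = SimpleRateLimiter(max_calls=100)
# if limiter.acquire():
#     output = sample_from_model(gen, inp, max_len=100, temperature=0.01)
```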
),
name="local",
load=CKPT_PATH,
tokenizer_path="./tokenizer.model",

Review comment: If you were to improve anything, I'd suggest improving how file paths are defined by utilizing …

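The reviewer's suggestion is cut off in this view. One common pattern (an assumption on my part, not necessarily what the reviewer meant) is to resolve paths relative to the script with pathlib instead of hard-coding relative strings:

```python
from pathlib import Path

# Resolve paths relative to this file rather than the current working directory.
# The "checkpoints" directory name is an assumption; only CKPT_PATH and
# "./tokenizer.model" appear in the visible diff.
REPO_ROOT = Path(__file__).resolve().parent
CKPT_PATH = REPO_ROOT / "checkpoints"
TOKENIZER_PATH = REPO_ROOT / "tokenizer.model"

# The runner arguments could then take the stringified paths, e.g.
# checkpoint_path=str(CKPT_PATH), tokenizer_path=str(TOKENIZER_PATH)
```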
local_mesh_config=(1, 8),
between_hosts_config=(1, 1),
)
inference_runner.initialize()
gen = inference_runner.run()
inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
# Add authentication
@app.route("/inference")
@auth.login_required
MiChaelinzo marked this conversation as resolved.
@@ -87,7 +92,7 @@ def inference():
gen = inference_runner.run()
# Rest of inference code
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
Review comment: 2 space indent is not standard. Please see PEP 8.

Reply (PR author): I'm using 1 space, and you should raise that with the original repo; you're making our lives complicated enough with reviews that don't make any sense at all!

Review comment: This change is not accepted and requires fixing. A 1 space indent is not standard and worsens readability and consistency across all affected code. Not accepted.

Review comment: Overall, a complete waste of a PR. Nothing of value was added.
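For reference, the @app.route("/inference") and @auth.login_required decorators added by the PR presuppose app and auth objects that never appear in the visible hunks. A minimal, self-contained sketch of how that wiring usually looks with Flask and Flask-HTTPAuth (library choice, credential handling, and the handler body are all assumptions, not taken from this PR):

```python
from flask import Flask, jsonify, request
from flask_httpauth import HTTPBasicAuth
from werkzeug.security import check_password_hash, generate_password_hash

app = Flask(__name__)
auth = HTTPBasicAuth()

# Illustrative credential store; a real deployment would load secrets from config,
# not hard-code them.
users = {"admin": generate_password_hash("change-me")}


@auth.verify_password
def verify_password(username, password):
    if username in users and check_password_hash(users[username], password):
        return username
    return None


@app.route("/inference", methods=["POST"])
@auth.login_required
def inference():
    prompt = request.get_json(force=True).get("prompt", "")
    # Here the generator returned by inference_runner.run() would be sampled,
    # e.g. sample_from_model(gen, prompt, max_len=100, temperature=0.01).
    return jsonify({"prompt": prompt, "output": "..."})


if __name__ == "__main__":
    app.run(host="127.0.0.1", port=8000)
```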