Quantization, fp acceleration, and testing #572

Merged (8 commits) on Apr 30, 2024

89 changes: 89 additions & 0 deletions .github/workflows/more-tests.yml
@@ -0,0 +1,89 @@
name: Run parallel prefill

on:
pull_request:
push:
branches:
- main
workflow_dispatch:

jobs:
test-cuda:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.1"
script: |
echo "::group::Print machine info"
uname -a
echo "::endgroup::"

echo "::group::Install newer objcopy that supports --set-section-alignment"
yum install -y devtoolset-10-binutils
export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
echo "::endgroup::"


echo "::group::Download checkpoints"
# Install requirements
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip3 install -r requirements.txt
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"

echo "::group::Download checkpoints"
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
echo "::endgroup::"

echo "::group::Run inference"
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp

for DTYPE in bfloat16 float16 float32; do
###################################################################
# group with different temperatures and top-k values
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500
###################################################################
# group with different temperatures and top-k values, with compile and prefill compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --compile --compile-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --compile --compile-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --compile --compile-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --compile --compile-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --compile --compile-prefill
###################################################################
# group with different temperatures and top-k values, with sequential prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --sequential-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --sequential-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --sequential-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --sequential-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --sequential-prefill
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --sequential-prefill
###################################################################
# group with different temperatures and top-k values, with sequential prefill and compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 0.9 --sequential-prefill --compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --temperature 1.0 --sequential-prefill --compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 100 --sequential-prefill --compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 200 --sequential-prefill --compile
python generate.py --checkpoint-path ${MODEL_PATH} --device cpu --dtype ${DTYPE} --top-k 500 --sequential-prefill --compile

done

echo "tests complete"
echo "******************************************"
echo "::endgroup::"

13 changes: 12 additions & 1 deletion build/builder.py
@@ -37,6 +37,7 @@ class BuilderArgs:
setup_caches: bool = False
use_tp: bool = False
is_chat_model: bool = False
prefill_possible: bool = False

def __post_init__(self):
if self.device is None:
@@ -68,6 +69,8 @@ def __post_init__(self):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
)
if not (self.dso_path) and not (self.pte_path):
self.prefill_possible = True

@classmethod
def from_args(cls, args): # -> BuilderArgs:
@@ -114,6 +117,14 @@ def from_args(cls, args): # -> BuilderArgs:
if "chat" in path_basename or "instruct" in path_basename:
is_chat_model = True

if args.output_pte_path and args.dtype.startswith("fast"):
if args.dtype == "fast":
dtype = torch.float32
else:
dtype = torch.float16
else:
dtype = name_to_dtype(args.dtype)

return cls(
checkpoint_dir=checkpoint_dir,
checkpoint_path=checkpoint_path,
@@ -124,7 +135,7 @@ def from_args(cls, args): # -> BuilderArgs:
dso_path=args.dso_path,
pte_path=args.pte_path,
device=args.device,
precision=name_to_dtype(args.dtype),
precision=dtype,
setup_caches=(args.output_dso_path or args.output_pte_path),
use_tp=False,
is_chat_model=is_chat_model,
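
For readers tracking the new "fast" and "fast16" names: the sketch below (not code from this PR) summarizes how the precision resolves, per the branch above. ExecuTorch (.pte) export pins "fast" to float32 and "fast16" to float16; every other case goes through name_to_dtype from build/utils.py (shown further down), which maps the "fast" names to a platform-appropriate dtype.

import torch

from build.utils import name_to_dtype  # helper extended in build/utils.py below


def resolve_precision(dtype_name: str, exporting_pte: bool) -> torch.dtype:
    # Mirrors the branch in BuilderArgs.from_args above: ExecuTorch export
    # skips the platform-dependent "fast" mapping and pins a concrete dtype.
    if exporting_pte and dtype_name.startswith("fast"):
        return torch.float32 if dtype_name == "fast" else torch.float16
    return name_to_dtype(dtype_name)


print(resolve_precision("fast", exporting_pte=True))    # torch.float32
print(resolve_precision("fast16", exporting_pte=True))  # torch.float16
print(resolve_precision("fast", exporting_pte=False))   # float16 on arm, bfloat16 otherwise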
85 changes: 82 additions & 3 deletions build/utils.py
@@ -9,9 +9,61 @@
import logging
import os
from pathlib import Path
from typing import Dict, List

##########################################################################
### unpack packed weights ###

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F


def unpack_packed_weights(
packed_weights: Dict[str, Any],
packed_linear: Callable,
input_dtype: torch.dtype,
unpacked_dims: Tuple,
) -> torch.Tensor:
"""Given a packed weight matrix `packed_weights`, a Callable
implementing a packed linear function for the packed format, and the
unpacked dimensions of the weights, recreate the unpacked weight
matrix. In addition to the packed weights, as a dictionary to specify
whatever arguments the packed routine expects, we also need the input
data type because packing may depend on input dtype, or only some
input dtypes may be supported. We also need the dimensions of the
unpacked matrix. At present, this does not handle padding, but that will
be straightforward to add. Similarly, the same approach can be used
for both linear and mm operators.

Args:
packed_weights: Dict[str, Any],
packed_linear: Callable,
input_dtype: torch.dtype,
unpacked_dims: Optional[Tuple]=None

Example usage:
packed_weights = {
"weight" : weight_int4pack,
"qGroupSize": groupsize,
"scales_and_zeros": scales_and_zeros
}
unpacked_weights = unpack_packed_weights(
_weight_int4pack_linear,
packed_weights,
torch.bfloat6,
(256, 1024),
)


"""
assert len(unpacked_dims) == 2, "unpacked_dims must be a tuple of length 2"
cols = unpacked_dims[1]

unpacked_weights = packed_linear(
torch.eye(cols, dtype=input_dtype), **packed_weights
).transpose(0, 1)
return unpacked_weights
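
A self-contained usage sketch (not from the PR) of the identity-matrix trick above: feeding an identity matrix through the packed linear yields the transposed weight, and the final transpose restores the unpacked matrix. The toy packed format below is hypothetical, purely for illustration.

from typing import Any, Dict

import torch


def toy_packed_linear(x: torch.Tensor, weight_t: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical packed format: the weight is stored transposed, with a scale factor.
    return x @ (weight_t * scale)


out_features, in_features = 8, 4
weight = torch.randn(out_features, in_features)
packed: Dict[str, Any] = {"weight_t": weight.t().contiguous() / 2.0, "scale": 2.0}

# unpack_packed_weights as defined above (build/utils.py)
recovered = unpack_packed_weights(packed, toy_packed_linear, torch.float32, (out_features, in_features))
assert torch.allclose(recovered, weight, atol=1e-6)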


##########################################################################
@@ -78,7 +130,17 @@ def get_precision():

##########################################################################
### dtype name to torch.dtype mapping ###


def name_to_dtype(name):
if (name == "fast") or (name == "fast16"):
import platform

if platform.processor() == "arm":
return torch.float16
else:
return torch.bfloat16

if name in name_to_dtype_dict:
return name_to_dtype_dict[name]
else:
@@ -98,6 +160,8 @@ def allowable_dtype_names() -> List[str]:
"float32": torch.float,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"fast": None,
"fast16": None,
}


@@ -157,12 +221,27 @@ def state_dict_device(d, device="cpu") -> Dict:
### move state dict to specified device ###


def is_mps_available() -> bool:
if not torch.backends.mps.is_available():
return False

# Our system may report that MPS is available even when it is not usable
# (e.g., on VMs), so allocate a small tensor on MPS and see if that works.
try:
mps_tensor = torch.zeros(1024, dtype=torch.float16, device="mps")
except Exception:
return False

# MPS is genuinely available and usable.
return True


def get_device_str(device) -> str:
if isinstance(device, str) and device == "fast":
return (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
else "mps" if is_mps_available() else "cpu"
)
else:
return str(device)
@@ -173,6 +252,6 @@ def get_device(device) -> str:
device = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
else "mps" if is_mps_available() else "cpu"
)
return torch.device(device)
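
A quick usage sketch (not part of the PR) of the "fast" device resolution above: it prefers cuda, then an MPS backend that actually works (per is_mps_available), then cpu.

import torch

from build.utils import get_device_str

device = get_device_str("fast")  # "cuda" on the CUDA CI runner above, otherwise "mps" or "cpu"
print(torch.device(device))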
13 changes: 6 additions & 7 deletions cli.py
@@ -12,7 +12,7 @@
from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
from download import download_and_convert, is_model_downloaded

default_device = "cpu"
default_device = "fast"


# Handle CLI arguments that are common to a majority of subcommands.
@@ -136,12 +136,12 @@ def _add_arguments_common(parser):
parser.add_argument(
"--compile-prefill",
action="store_true",
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times. (Requires `--parallel-prefill`)",
help="Whether to compile the prefill. Improves prefill perf, but has higher compile times.",
)
parser.add_argument(
"--parallel-prefill",
"--sequential-prefill",
action="store_true",
help="Whether to perform prefill in parallel, or one token at a time. Improves prefill perf. DSO and PTE models presently do not support parallel prefill.",
help="Whether to perform prefill sequentially. Only used for model debug.",
)
parser.add_argument(
"--profile",
@@ -210,11 +210,10 @@ def _add_arguments_common(parser):
help="Use the specified ExecuTorch .pte model file",
)
parser.add_argument(
"-d",
"--dtype",
default="float32",
default="fast",
choices=allowable_dtype_names(),
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
)
parser.add_argument(
"-v",
2 changes: 1 addition & 1 deletion export_et_util.py
@@ -10,7 +10,7 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
super().__init__()

dtype = torch.float

# This is flipped around from what is in build.model's KVCache
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
self.register_buffer(
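
An illustration (not from the PR) of the layout flip the comment above refers to: this ExecuTorch cache is registered as (batch, seq, heads, head_dim), while the eager KVCache in build.model is assumed here to use (batch, heads, seq, head_dim); a single transpose maps between the two.

import torch

b, s, h, d = 1, 128, 8, 64
et_cache = torch.zeros(b, s, h, d)      # layout used in export_et_util.py above
eager_cache = et_cache.transpose(1, 2)  # assumed (batch, heads, seq, head_dim) eager layout
print(eager_cache.shape)                # torch.Size([1, 8, 128, 64])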