implement --device 'fast' #511

Merged · 2 commits · Apr 27, 2024
build/utils.py: 29 additions & 3 deletions
@@ -9,7 +9,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import List, Dict
+from typing import Dict, List

 import torch

@@ -143,10 +143,36 @@ def canonical_path(path):


 #########################################################################
-### general utility functions ###
+### move state dict to specified device ###


-def state_dict_device(d, device = "cpu") -> Dict:
+def state_dict_device(d, device="cpu") -> Dict:
     for key, weight in d.items():
         d[key] = weight.to(device=device)

     return d
+
+
+#########################################################################
+### determine the device to use ###
+
+
+def get_device_str(device) -> str:
+    if isinstance(device, str) and device == "fast":
+        return (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
+    else:
+        return str(device)
+
+
+def get_device(device) -> torch.device:
+    if isinstance(device, str) and device == "fast":
+        device = (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
+    return torch.device(device)
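
For illustration, a minimal sketch of how these helpers resolve the new "fast" alias. Output depends on the machine; the comments assume an Apple-silicon host with MPS but no CUDA:

```python
from build.utils import get_device, get_device_str

# "fast" picks the best available backend in priority order: cuda, mps, cpu.
print(get_device_str("fast"))    # -> "mps" on the assumed host
print(get_device_str("cuda:1"))  # non-"fast" values pass through unchanged

# get_device applies the same resolution but returns a torch.device object.
dev = get_device("fast")
print(dev)                       # -> device(type='mps')
```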
cli.py: 6 additions & 3 deletions
@@ -9,7 +9,7 @@

 import torch

-from build.utils import allowable_dtype_names, allowable_params_table
+from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded

 default_device = "cpu"
@@ -252,7 +252,7 @@ def _add_arguments_common(parser):
         "--device",
         type=str,
         default=default_device,
-        choices=["cpu", "cuda", "mps"],
+        choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: cpu, cuda, mps",
     )
     parser.add_argument(
@@ -300,7 +300,7 @@ def arg_init(args):
             "You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )

-    if hasattr(args, 'quantize') and Path(args.quantize).is_file():
+    if hasattr(args, "quantize") and Path(args.quantize).is_file():
         with open(args.quantize, "r") as f:
             args.quantize = json.loads(f.read())

@@ -309,6 +309,9 @@

     # if we specify dtype in quantization recipe, replicate it as args.dtype
     args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    args.device = get_device_str(
+        args.quantize.get("executor", {}).get("accelerator", args.device)
+    )

     if hasattr(args, "seed") and args.seed:
         torch.manual_seed(args.seed)
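
With this hook, a quantization recipe can choose the device itself. A hypothetical quantize.json passed via --quantize might look like the following; the "executor"/"accelerator" and "precision"/"dtype" keys mirror the .get() calls above, while the dtype value is illustrative:

```json
{
  "executor": {"accelerator": "fast"},
  "precision": {"dtype": "bf16"}
}
```

arg_init then routes the "accelerator" value through get_device_str, so "fast" is resolved to cuda, mps, or cpu once at argument-parsing time.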
quantize.py: 34 additions & 6 deletions
@@ -15,7 +15,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend, state_dict_device
+from build.utils import (
+    find_multiple,
+    get_device_str,
+    get_precision,
+    name_to_dtype,
+    state_dict_device,
+    use_et_backend,
+)


 #########################################################################

@@ -116,6 +122,28 @@ def quantized_model(self) -> nn.Module:
     return self.model_.to(device=self.device, dtype=self.dtype)


+#########################################################################
+### wrapper for setting device as a QuantHandler ###
+
+
+class ExecutorHandler(QuantHandler):
+    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, accelerator):
+        self.model_ = model
+
+        if isinstance(accelerator, str):
+            device = get_device_str(accelerator)
+        self.device = device
+
+    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+        pass
+
+    def convert_for_runtime(self) -> nn.Module:
+        pass
+
+    def quantized_model(self) -> nn.Module:
+        return self.model_.to(device=self.device)
+
+
 #########################################################################
 ##### Quantization Primitives ######

@@ -407,8 +435,8 @@ def __init__(
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
         cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu" # self.device
+        dict_device = "cpu"  # self.device

         if self.bitwidth == 4:
             range_min = -8
             range_max = 7

@@ -824,12 +852,11 @@ def __init__(
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

-
     # @torch.no_grad()
     # def p(self):
     #     cur_state_dict = state_dict_device(self.model_.state_dict())
     #     dict_device = "cpu"  # self.device
-    #
+    #
     #     for fqn, mod in self.model_.named_modules():
     #         if hasattr(mod, "weight"):
     #             print(f"device={str(mod.weight.data.device)}")

@@ -838,7 +865,7 @@ def __init__(
     def create_quantized_state_dict(self):
         cur_state_dict = state_dict_device(self.model_.state_dict())
         dict_device = "cpu"  # self.device
-
+
         for fqn, mod in self.model_.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 assert not mod.bias

@@ -1282,4 +1309,5 @@ def quantized_model(self) -> nn.Module:
"linear:int4-gptq": WeightOnlyInt4GPTQQuantHandler,
"linear:hqq": WeightOnlyInt4HqqQuantHandler,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
}
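
To sketch how the new table entry is exercised: each key in a quantization recipe is looked up in this table and the matching handler transforms the model. The driver loop below is a simplified illustration, not the repo's actual dispatch code, and it assumes the table above is named quantizer_class_dict:

```python
import torch.nn as nn

def apply_recipe(model: nn.Module, recipe: dict, device: str = "cpu") -> nn.Module:
    # Illustrative only: look up each recipe entry ("executor", "precision", ...)
    # and let the matching handler transform the model in sequence.
    for key, kwargs in recipe.items():
        handler = quantizer_class_dict[key](model, device, tokenizer=None, **kwargs)
        model = handler.quantized_model()
    return model

# A recipe of {"executor": {"accelerator": "fast"}} would construct
# ExecutorHandler, resolve "fast" via get_device_str, and move the model there.
```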