
do weight transform on cpu #508

Merged 1 commit on Apr 27, 2024
14 changes: 12 additions & 2 deletions build/utils.py
@@ -9,7 +9,7 @@
import logging
import os
from pathlib import Path
from typing import List
from typing import List, Dict

import torch

@@ -133,10 +133,20 @@ def device_sync(device="cpu"):


#########################################################################
### general utilkity functions ###
### general utility functions ###


# in fbcode, we can intercept certain local paths that
# should be interpreted as part of an XAR package
def canonical_path(path):
return path


#########################################################################
### general utility functions ###

def state_dict_device(d, device = "cpu") -> Dict:
for key, weight in d.items():
d[key] = weight.to(device=device)

return d
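
For context, a minimal usage sketch of the new helper (the toy `nn.Linear` module is a placeholder; the `build.utils` import path matches how quantize.py imports it below):

```python
import torch.nn as nn

from build.utils import state_dict_device

# Any module works; a single Linear layer keeps the example small.
model = nn.Linear(16, 16, bias=False)

# Move every tensor in the state dict to the default device ("cpu"),
# e.g. before quantizing it or saving it to disk.
cpu_sd = state_dict_device(model.state_dict())
assert all(t.device.type == "cpu" for t in cpu_sd.values())
```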
39 changes: 27 additions & 12 deletions quantize.py
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend
from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend, state_dict_device


#########################################################################
@@ -63,7 +63,7 @@ def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
model_updated_state_dict = state_dict_device(self.create_quantized_state_dict())
self.convert_for_runtime()
self.model_.load_state_dict(model_updated_state_dict)
return self.model_
@@ -406,8 +406,9 @@ def __init__(

@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
cur_state_dict = self.model_.state_dict()

cur_state_dict = state_dict_device(self.model_.state_dict())
dict_device = "cpu" # self.device

if self.bitwidth == 4:
range_min = -8
range_max = 7
@@ -446,8 +447,8 @@ def create_quantized_state_dict(self) -> Dict:
scales_dtype=mod.weight.dtype,
)

weight = weight.to(device=self.device)
scales = scales.to(device=self.device)
weight = weight.to(device=dict_device)
scales = scales.to(device=dict_device)
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
@@ -553,7 +554,8 @@ def __init__(

@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
cur_state_dict = self.model_.state_dict()
cur_state_dict = state_dict_device(self.model_.state_dict())
dict_device = "cpu" # self.device

if self.bitwidth == 4:
range_min = -8
@@ -595,8 +597,8 @@ def create_quantized_state_dict(self) -> Dict:
weight_packed = weight_even + weight_odd
weight = weight_packed

weight = weight.to(device=self.device)
scales = scales.to(device=self.device)
weight = weight.to(device=dict_device)
scales = scales.to(device=dict_device)
# Update state dict
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
@@ -822,9 +824,21 @@ def __init__(
assert groupsize in [32, 64, 128, 256]
assert inner_k_tiles in [2, 4, 8]


# @torch.no_grad()
# def p(self):
# cur_state_dict = state_dict_device(self.model_.state_dict())
# dict_device = "cpu" # self.device
#
# for fqn, mod in self.model_.named_modules():
# if hasattr(mod, "weight"):
# print(f"device={str(mod.weight.data.device)}")

@torch.no_grad()
def create_quantized_state_dict(self):
cur_state_dict = self.model_.state_dict()
cur_state_dict = state_dict_device(self.model_.state_dict())
dict_device = "cpu" # self.device

for fqn, mod in self.model_.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
@@ -856,8 +870,8 @@ def create_quantized_state_dict(self):
weight.to(torch.float), self.groupsize, self.inner_k_tiles
)
)
weight_int4pack = weight_int4pack.to(device=self.device)
scales_and_zeros = scales_and_zeros.to(device=self.device)
weight_int4pack = weight_int4pack.to(device=dict_device)
Contributor:

So the CPU-packed weight and the CUDA-packed weight go through different int4mm kernels, and weights prepared on one device may not be compatible with the other (they give wrong results), as recently discovered by @HDCharles; it's a silent error right now. Have we done any evaluation of accuracy for this change (on CUDA)?

Contributor Author:

We have not done this, but @malfet and I have discussed it, and Intel had previously promised us an unpack routine, so we would be able to unpack() the different formats. BTW, we need an unpack for the GPU packing format as well.
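
For illustration only, here is a self-contained sketch of the pack/unpack round trip being discussed. `naive_pack_int4` and `naive_unpack_int4` are didactic stand-ins defined in the snippet, not the real CPU or CUDA int4 kernel layouts; the point is simply that a weight packed in one layout has to be unpacked back to plain int4 values before it can be re-packed for a different device.

```python
import torch


def naive_pack_int4(plain: torch.Tensor) -> torch.Tensor:
    # Didactic stand-in for a device-specific packing routine:
    # store two 4-bit values per byte, low nibble first.
    assert plain.dtype == torch.uint8 and plain.shape[-1] % 2 == 0
    low = plain[..., 0::2] & 0x0F
    high = plain[..., 1::2] & 0x0F
    return low | (high << 4)


def naive_unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of naive_pack_int4: recover the plain 4-bit values.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack((low, high), dim=-1).flatten(start_dim=-2)


def repack_for_device(packed: torch.Tensor, target_device: str) -> torch.Tensor:
    # A real version would call the (still missing) unpack for the source layout,
    # then the packing op used by the int4mm kernel on target_device.
    plain = naive_unpack_int4(packed)
    return naive_pack_int4(plain.to(target_device))


if __name__ == "__main__":
    w = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
    assert torch.equal(naive_unpack_int4(naive_pack_int4(w)), w)
```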

scales_and_zeros = scales_and_zeros.to(device=dict_device)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

@@ -877,6 +891,7 @@ def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.model_.load_state_dict(model_updated_state_dict)
# self.p()
return self.model_

