do weight transform on cpu #508
Conversation
Branch force-pushed from 8e60a72 to 59e6275
It works...
Branch force-pushed from 59e6275 to fe3e608
```diff
@@ -856,8 +870,8 @@ def create_quantized_state_dict(self):
                 weight.to(torch.float), self.groupsize, self.inner_k_tiles
             )
         )
-        weight_int4pack = weight_int4pack.to(device=self.device)
-        scales_and_zeros = scales_and_zeros.to(device=self.device)
+        weight_int4pack = weight_int4pack.to(device=dict_device)
```
CPU-packed weights and CUDA-packed weights call different int4mm kernels, and weights prepared on one device may not be compatible with the other (they give wrong results), as recently discovered by @HDCharles; right now this is a silent error. Have we done any accuracy evaluation of this change (on CUDA)?
We have not done this, but @malfet and I have discussed it, and Intel had previously promised us an unpack routine, so we would be able to unpack() the different formats. BTW, we need an unpack for the GPU packing format as well.
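For illustration, here is a minimal sketch of the kind of cross-device consistency check being discussed, assuming a PyTorch build where torch.ops.aten._convert_weight_to_int4pack and torch.ops.aten._weight_int4pack_mm have both CPU and CUDA implementations; the expected input dtypes for these ops have changed across releases, and the helper name below is made up for the example, so treat this as a sketch rather than a drop-in test:

```python
import torch

def int4_mm_agrees_across_pack_devices(n=64, k=128, groupsize=32, inner_k_tiles=8):
    # Fake already-quantized weight values in [0, 16); some releases expect
    # uint8 here instead of int32.
    weight_q = torch.randint(0, 16, (n, k), dtype=torch.int32)
    # Dummy scales/zeros in the (k // groupsize, n, 2) bfloat16 layout used by
    # the int4 kernels; the values only need to match between the two paths.
    scales_and_zeros = torch.ones(k // groupsize, n, 2, dtype=torch.bfloat16)
    x = torch.randn(4, k, dtype=torch.bfloat16)

    # Pack on CPU and move the packed tensor to CUDA ...
    packed_from_cpu = torch.ops.aten._convert_weight_to_int4pack(weight_q, inner_k_tiles).cuda()
    # ... versus packing directly on CUDA.
    packed_on_cuda = torch.ops.aten._convert_weight_to_int4pack(weight_q.cuda(), inner_k_tiles)

    y_cpu_pack = torch.ops.aten._weight_int4pack_mm(
        x.cuda(), packed_from_cpu, groupsize, scales_and_zeros.cuda()
    )
    y_cuda_pack = torch.ops.aten._weight_int4pack_mm(
        x.cuda(), packed_on_cuda, groupsize, scales_and_zeros.cuda()
    )
    # If the CPU and CUDA packing layouts are incompatible, this returns False.
    return torch.allclose(y_cpu_pack, y_cuda_pack)
```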
To avoid OOM situations, do the weight transform on CPU.
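As an illustration of the pattern (not the exact torchchat code), here is a minimal sketch; `pack_fn` stands in for whatever int4 packing helper the quantizer uses, and `dict_device` is the device the finished state dict should end up on, as in the diff above:

```python
import torch

def pack_on_cpu_then_move(weight, pack_fn, groupsize, inner_k_tiles, dict_device):
    # Run the memory-hungry packing step on CPU so the full-precision weight
    # and intermediates never have to live on the accelerator ...
    weight_int4pack, scales_and_zeros = pack_fn(
        weight.to(torch.float).cpu(), groupsize, inner_k_tiles
    )
    # ... then move only the much smaller packed tensors to the target device.
    return (
        weight_int4pack.to(device=dict_device),
        scales_and_zeros.to(device=dict_device),
    )
```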
load_state_dict does not have a map_location argument like torch.load() does.
However, a quick check on mps suggested that when the state dict is loaded back into a model, the tensors are placed on the device of the model's pre-existing parameters. If this does not work as expected in some scenarios, it should be quick to write a new version of load_state_dict that takes a map_location, or to amend the API of the existing function.
We might also avoid instantiating the state dict altogether; this pattern is inherited from gpt-fast to enable sharing of quantization algorithms between torchchat and gpt-fast.
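For reference, a minimal sketch of the loading path described above; the helper name is hypothetical, but torch.load already accepts map_location, and nn.Module.load_state_dict then copies each tensor into the module's pre-existing parameters and buffers, so the data ends up on whatever device the model already lives on:

```python
import torch

def load_checkpoint_with_map_location(model, checkpoint_path, map_location="cpu"):
    # Materialize the checkpoint on CPU first to avoid a full-size staging
    # copy on the accelerator.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    # load_state_dict copies into the model's existing (possibly CUDA/MPS)
    # parameters, so no explicit per-tensor .to(device) is needed here.
    model.load_state_dict(state_dict)
    return model
```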