
Commit 9697399

mikekgfb authored and malfet committed
do weight transform on cpu (#508)
1 parent be43ce9 commit 9697399

File tree

build/utils.py
quantize.py

2 files changed: +39 -14 lines

build/utils.py

Lines changed: 12 additions & 2 deletions
@@ -9,7 +9,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Dict
 
 import torch
 

@@ -133,10 +133,20 @@ def device_sync(device="cpu"):
 
 
 #########################################################################
-### general utilkity functions ###
+### general utility functions ###
 
 
 # in fbcode, we can intercept certain local paths that
 # should be interpreted as part of an XAR package
 def canonical_path(path):
     return path
+
+
+#########################################################################
+### general utility functions ###
+
+def state_dict_device(d, device = "cpu") -> Dict:
+    for key, weight in d.items():
+        d[key] = weight.to(device=device)
+
+    return d
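
The new state_dict_device helper moves every tensor in a state dict to the target device; note that it mutates the dict in place and returns it for chaining. A minimal usage sketch, assuming build.utils is on the import path and using a stand-in nn.Linear model:

    import torch.nn as nn

    from build.utils import state_dict_device

    model = nn.Linear(4, 4)  # hypothetical stand-in model
    sd = state_dict_device(model.state_dict(), device="cpu")
    assert all(t.device.type == "cpu" for t in sd.values())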

quantize.py

Lines changed: 27 additions & 12 deletions
@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend
+from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend, state_dict_device
 
 
 #########################################################################

@@ -63,7 +63,7 @@ def convert_for_runtime(self) -> nn.Module:
         pass
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = state_dict_device(self.create_quantized_state_dict())
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
         return self.model_
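
Forcing the quantized state dict onto CPU is safe even when the model itself lives on another device: nn.Module.load_state_dict copies each value into the module's existing parameters and buffers, and that copy works across devices. A minimal sketch of the property, assuming a CUDA device is available:

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2).to("cuda")  # module parameters live on GPU
    cpu_sd = {k: v.cpu() for k, v in model.state_dict().items()}
    model.load_state_dict(cpu_sd)  # CPU tensors are copied onto the GPU params
    assert model.weight.device.type == "cuda"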
@@ -406,8 +406,9 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = self.model_.state_dict()
-
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu" # self.device
+
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7

@@ -446,8 +447,8 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
+                weight = weight.to(device=dict_device)
+                scales = scales.to(device=dict_device)
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
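
For bitwidth 4 the quantizer clamps to the signed range [-8, 7]; bitwidth 8 would use [-128, 127]. Below is a sketch of symmetric per-group quantization consistent with those ranges; the exact grouping and scale computation in this repository's kernels may differ, so treat the helper as an illustration only:

    import torch

    def quantize_symmetric(w, groupsize, range_min, range_max):
        # One scale per contiguous group of `groupsize` elements.
        g = w.reshape(-1, groupsize)
        scales = g.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / range_max
        q = torch.clamp(torch.round(g / scales), range_min, range_max).to(torch.int8)
        return q.reshape(w.shape), scales

    q, scales = quantize_symmetric(torch.randn(8, 64), groupsize=32, range_min=-8, range_max=7)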
@@ -553,7 +554,8 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = self.model_.state_dict()
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu" # self.device
 
         if self.bitwidth == 4:
             range_min = -8

@@ -595,8 +597,8 @@ def create_quantized_state_dict(self) -> Dict:
                 weight_packed = weight_even + weight_odd
                 weight = weight_packed
 
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
+                weight = weight.to(device=dict_device)
+                scales = scales.to(device=dict_device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
                 # squeeze makes groupsize=rowsize unidimensional
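
The weight_even + weight_odd arithmetic packs two 4-bit values into each byte, even-indexed elements in the high nibble and odd-indexed elements in the low nibble. A standalone sketch of that packing; the signed-to-unsigned offset and the nibble order are assumptions read off the variable names, not the repository's exact kernel:

    import torch

    def pack_int4(q):
        # q holds signed 4-bit values (range [-8, 7]) in an int8 tensor.
        assert q.shape[-1] % 2 == 0
        u = (q + 8).to(torch.uint8)      # shift to the unsigned range [0, 15]
        weight_even = u[..., ::2] * 16   # even elements into the high nibble
        weight_odd = u[..., 1::2]        # odd elements into the low nibble
        return weight_even + weight_odd  # one byte per pair of 4-bit values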
@@ -822,9 +824,21 @@ def __init__(
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
+
+    # @torch.no_grad()
+    # def p(self):
+    #     cur_state_dict = state_dict_device(self.model_.state_dict())
+    #     dict_device = "cpu" # self.device
+    #
+    #     for fqn, mod in self.model_.named_modules():
+    #         if hasattr(mod, "weight"):
+    #             print(f"device={str(mod.weight.data.device)}")
+
     @torch.no_grad()
     def create_quantized_state_dict(self):
-        cur_state_dict = self.model_.state_dict()
+        cur_state_dict = state_dict_device(self.model_.state_dict())
+        dict_device = "cpu" # self.device
+
         for fqn, mod in self.model_.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 assert not mod.bias
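
The commented-out p() method is a debugging aid left in the diff: it walks the module tree and prints which device each weight lives on, useful for verifying that the transform really stayed on CPU. A standalone, runnable version of the same audit (the function name is hypothetical):

    import torch.nn as nn

    def print_weight_devices(model: nn.Module) -> None:
        # Report where every weight tensor lives after quantization.
        for fqn, mod in model.named_modules():
            if getattr(mod, "weight", None) is not None:
                print(f"{fqn}: device={mod.weight.data.device}")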
@@ -856,8 +870,8 @@ def create_quantized_state_dict(self):
                     weight.to(torch.float), self.groupsize, self.inner_k_tiles
                 )
             )
-            weight_int4pack = weight_int4pack.to(device=self.device)
-            scales_and_zeros = scales_and_zeros.to(device=self.device)
+            weight_int4pack = weight_int4pack.to(device=dict_device)
+            scales_and_zeros = scales_and_zeros.to(device=dict_device)
             cur_state_dict[f"{fqn}.weight"] = weight_int4pack
             cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
 

@@ -877,6 +891,7 @@ def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict()
         self.convert_for_runtime()
         self.model_.load_state_dict(model_updated_state_dict)
+        # self.p()
         return self.model_
 
 