special quant logic for et #393

Closed · wants to merge 1 commit
20 changes: 19 additions & 1 deletion build/builder.py
@@ -318,7 +318,25 @@ def _initialize_model(
     builder_args,
     quantize,
     tokenizer=None,
+    is_et=None,
 ):
+    # Infer is_et from dso_path or pte_path if is_et is not explicitly provided
+    if is_et is None:
+        is_dso = builder_args.dso_path is not None
+        is_pte = builder_args.pte_path is not None
+        assert is_dso or is_pte
+        assert not (is_dso and is_pte)
+
+        is_et = is_pte
+
+    # Assert is_et is consistent with dso_path or pte_path, if provided
+    is_dso = builder_args.dso_path is not None
+    is_pte = builder_args.pte_path is not None
+    assert not (is_dso and is_pte)
+
+    assert not (is_pte and not is_et)
+    assert not (is_dso and is_et)
+
     print("Loading model ...")
     t0 = time.time()

@@ -370,7 +388,7 @@ def _initialize_model(
     if quantize:
         t0q = time.time()
         print(f"Quantizing the model with: {quantize}")
-        quantize_model(model, builder_args.device, quantize, tokenizer)
+        quantize_model(model, builder_args.device, quantize, tokenizer, is_et)
         device_sync(device=builder_args.device)
         print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

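Note on the hunk above: the is_et inference and the consistency asserts can be read as one small rule. The sketch below restates it as a standalone helper for illustration only; the helper name and the stripped-down args container are not part of this patch.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Args:
    # stand-in for builder_args; only the two fields the rule reads
    dso_path: Optional[str] = None
    pte_path: Optional[str] = None


def infer_is_et(builder_args: _Args, is_et: Optional[bool] = None) -> bool:
    is_dso = builder_args.dso_path is not None
    is_pte = builder_args.pte_path is not None
    if is_et is None:
        # not told explicitly: exactly one output path must be set,
        # and a .pte output means the ExecuTorch (ET) quantization path
        assert is_dso or is_pte
        assert not (is_dso and is_pte)
        is_et = is_pte
    # consistency checks, also applied when is_et was passed in
    assert not (is_dso and is_pte)
    assert not (is_pte and not is_et)  # a .pte output requires the ET path
    assert not (is_dso and is_et)      # a .dso output must not use the ET path
    return is_et


assert infer_is_et(_Args(pte_path="model.pte")) is True
assert infer_is_et(_Args(dso_path="model.so")) is False
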
48 changes: 20 additions & 28 deletions export.py
@@ -52,39 +52,31 @@ def main(args):
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
 
-    # TODO: clean this up
-    # This mess is because ET does not support _weight_int4pack_mm right now
-    if not builder_args.gguf_path:
-        # tokenizer needed for quantization so get that here,
-        try:
-            tokenizer_args = TokenizerArgs.from_args(args)
-            tokenizer = _initialize_tokenizer(tokenizer_args)
-        except:
-            tokenizer = None
-
-        model = _initialize_model(
+    try:
+        tokenizer_args = TokenizerArgs.from_args(args)
+        tokenizer = _initialize_tokenizer(tokenizer_args)
+    except:
+        tokenizer = None
+
+    if output_pte_path:
+        _set_gguf_kwargs(builder_args, is_et=True, context="export")
+        model_to_pte = _initialize_model(
             builder_args,
             quantize,
             tokenizer,
+            is_et=True,
         )
-        model_to_pte = model
-        model_to_dso = model
-    else:
-        if output_pte_path:
-            _set_gguf_kwargs(builder_args, is_et=True, context="export")
-            model_to_pte = _initialize_model(
-                builder_args,
-                quantize,
-            )
-            _unset_gguf_kwargs(builder_args)
+        _unset_gguf_kwargs(builder_args)
 
-        if output_dso_path:
-            _set_gguf_kwargs(builder_args, is_et=False, context="export")
-            model_to_dso = _initialize_model(
-                builder_args,
-                quantize,
-            )
-            _unset_gguf_kwargs(builder_args)
+    if output_dso_path:
+        _set_gguf_kwargs(builder_args, is_et=False, context="export")
+        model_to_dso = _initialize_model(
+            builder_args,
+            quantize,
+            tokenizer,
+            is_et=False,
+        )
+        _unset_gguf_kwargs(builder_args)
 
     with torch.no_grad():
         if output_pte_path:
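
Net effect of this hunk: the gguf-path special case is gone, and each export target initializes its own model with an explicit is_et flag, so the quantizers know whether they must emit ExecuTorch-lowerable ops. A self-contained sketch of that shape follows; the stub stands in for build.builder._initialize_model and only records what it was told, and the example quantize string is illustrative rather than a guaranteed config key.

from typing import Optional


def initialize_model_stub(quantize: Optional[str], tokenizer, is_et: bool) -> dict:
    # stand-in for build.builder._initialize_model: just record what the
    # quantizers would be told instead of loading and quantizing a model
    return {"quantize": quantize, "tokenizer": tokenizer, "is_et": is_et}


def export_models(output_pte_path: Optional[str], output_dso_path: Optional[str],
                  quantize: Optional[str] = None, tokenizer=None):
    model_to_pte = model_to_dso = None
    if output_pte_path:
        # ExecuTorch (.pte) export: quantizers must pick ET-lowerable ops
        model_to_pte = initialize_model_stub(quantize, tokenizer, is_et=True)
    if output_dso_path:
        # DSO (.so) export: the eager quantized kernels are fine here
        model_to_dso = initialize_model_stub(quantize, tokenizer, is_et=False)
    return model_to_pte, model_to_dso


# illustrative config string; the real quantization dict format is the one
# described in quantize.py's docstring
pte_model, dso_model = export_models("llama.pte", None, quantize='{"embedding": {"bitwidth": 8}}')
assert pte_model["is_et"] is True and dso_model is None
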
33 changes: 25 additions & 8 deletions quantize.py
@@ -22,7 +22,7 @@
 ### torchchat quantization API ###
 
 
-def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
+def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None, is_et=False):
     """
     Quantize the specified model using the quantizers described by
     a quantization dict of the form:
@@ -37,6 +37,9 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
+        q_kwargs["is_et"] = is_et
+
+
         if quantizer not in quantizer_class_dict:
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
 
@@ -520,7 +523,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
+    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False, is_et=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -536,11 +539,12 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     embedding_dim=child.weight.shape[1],
                     groupsize=groupsize,
                     packed=packed,
+                    is_et=is_et,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, groupsize, packed
+                child, device, bitwidth, groupsize, packed, is_et
             )


@@ -554,6 +558,7 @@ def __init__(
         bitwidth: int = 8,
         groupsize: Optional[int] = None,
         packed=True,
+        is_et=False,
     ):
         # when quantization dictionary comes from JSON, packed is a string
         if isinstance(packed, str):
@@ -563,6 +568,7 @@ def __init__(
         self.groupsize = groupsize
         self.bitwidth = bitwidth
         self.packed = packed
+        self.is_et = is_et
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
@@ -619,7 +625,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.model_, self.device, self.bitwidth, self.groupsize, self.packed
+            self.model_, self.device, self.bitwidth, self.groupsize, self.packed, self.is_et
         )
         return self.model_

@@ -639,13 +645,15 @@ def __init__(
         groupsize: Optional[int] = None,
         dtype=torch.half,
         packed=False,
+        is_et=False,
     ) -> None:
         super().__init__()
         if groupsize is None or groupsize == 0:
             groupsize = embedding_dim
         self.groupsize = groupsize
         self.dtype = dtype
         self.packed = packed
+        self.is_et = is_et
         if not packed:
             self.register_buffer(
                 "weight",
@@ -675,10 +683,19 @@ def __init__(

     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False: # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
+
+        # et-path
+        if self.is_et:
+            if not self.packed: # 8bit
+                return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                )
+            else: # 4bit packed
+                return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                )
+
+        # non-et path
 
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
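
For context on what forward now does: on the ET path it calls the quantized_decomposed embedding ops added above, while the non-ET path dequantizes in eager PyTorch along the lines of the index_select calls visible at the end of the hunk. Below is a rough, self-contained sketch of that eager group-wise dequantization for the unpacked 8-bit case; the shapes and the per-row, per-group scale layout are assumptions for illustration, not the exact layout the handler produces.

import torch


def dequant_embedding_8bit(weight: torch.Tensor,   # (vocab, embed_dim), int8
                           scales: torch.Tensor,   # (vocab, embed_dim // groupsize), float
                           indices: torch.Tensor,
                           groupsize: int,
                           dtype=torch.half) -> torch.Tensor:
    w = weight.index_select(0, indices.view(-1))    # (N, embed_dim)
    s = scales.index_select(0, indices.view(-1))    # (N, n_groups)
    # dequantize group-wise: every `groupsize` columns share one scale
    w = w.view(w.shape[0], -1, groupsize).to(dtype) * s.unsqueeze(-1).to(dtype)
    return w.view(*indices.shape, -1)


vocab, dim, gs = 32, 16, 8
out = dequant_embedding_8bit(
    torch.randint(-128, 127, (vocab, dim), dtype=torch.int8),
    torch.rand(vocab, dim // gs),
    torch.tensor([[1, 5], [2, 7]]),
    groupsize=gs,
)
print(out.shape)  # torch.Size([2, 2, 16])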