 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision, use_et_backend
+from build.utils import find_multiple, get_precision
 
 
 #########################################################################
@@ -92,6 +92,30 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)
 
 
+#########################################################################
+###              QuantHandler API definition                          ###
+###           (unify with torchao in future)                          ###
+
+
+class QuantHandler:
+    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
+        self.model_ = model
+        self.device = device
+        self.tokenizer = tokenizer
+
+    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
+        pass
+
+    def convert_for_runtime(self) -> nn.Module:
+        pass
+
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.model_.load_state_dict(model_updated_state_dict)
+        return self.model_
+
+
 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###
 
@@ -623,12 +647,6 @@ def __init__(
         self.groupsize = groupsize
         self.dtype = dtype
         self.packed = packed
-
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
-
         if not packed:
             self.register_buffer(
                 "weight",
@@ -657,18 +675,12 @@ def __init__(
         )
 
     @torch.no_grad()
-    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if self.packed:
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
-        else:
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+    def forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if False:  # Used for Executorch
+            return torch.ops.llama_quantized.embedding_byte.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
 
-    @torch.no_grad()
-    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))
 
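
The second hunk restores the QuantHandler base class as the common interface that concrete quantizers implement: create_quantized_state_dict() produces the quantized weights, convert_for_runtime() rewrites the module graph, and quantized_model() drives both. Below is a minimal sketch, outside this diff, of how a handler might plug into that contract; the Int8WeightOnlyHandler name and its fake-quantization scheme are assumptions for illustration, not code from the repository.

from typing import Dict

import torch
import torch.nn as nn

# Assumes the QuantHandler base class added in this commit is in scope
# (e.g. imported from the quantize module).


class Int8WeightOnlyHandler(QuantHandler):  # hypothetical example subclass
    def create_quantized_state_dict(self) -> Dict:
        # Fake-quantize floating-point weights to int8 and store the
        # dequantized values so load_state_dict() still sees matching shapes.
        new_state = {}
        for name, tensor in self.model_.state_dict().items():
            if tensor.is_floating_point() and tensor.dim() >= 2:
                scale = tensor.abs().max().clamp(min=1e-8) / 127.0
                q = torch.clamp(torch.round(tensor / scale), -128, 127)
                new_state[name] = q * scale
            else:
                new_state[name] = tensor
        return new_state

    def convert_for_runtime(self) -> nn.Module:
        # A real handler would swap nn.Linear layers for quantized modules here;
        # this sketch leaves the module structure unchanged.
        return self.model_


# Usage, mirroring QuantHandler.quantized_model() from the diff:
#   handler = Int8WeightOnlyHandler(model, device="cpu")
#   model = handler.quantized_model()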
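
The final hunk collapses the Executorch/AOTI split into a single forward(); its eager path, hinted at by the commented-out index_select lines, gathers quantized rows and their scales and dequantizes them. The following is a minimal sketch of that pattern, assuming an int8 weight table with per-group scales; the function name and tensor layout are illustrative, not the file's actual implementation.

import torch


def dequantize_embedding_rows(
    weight: torch.Tensor,    # (vocab_size, embed_dim), int8 quantized values
    scales: torch.Tensor,    # (vocab_size, embed_dim // groupsize), per-group scales
    indices: torch.Tensor,   # token ids of arbitrary shape
    groupsize: int,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Gather the quantized rows and matching scales for the requested tokens.
    rows = weight.index_select(0, indices.view(-1)).to(dtype)
    row_scales = scales.index_select(0, indices.view(-1)).to(dtype)
    # Apply one scale per group of `groupsize` channels, then restore shape.
    rows = rows.view(rows.shape[0], -1, groupsize) * row_scales.unsqueeze(-1)
    return rows.view(*indices.shape, -1)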