 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import find_multiple, get_precision
+from build.utils import find_multiple, get_precision, use_et_backend


 #########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantizer.quantize(self.model_)


-#########################################################################
-### QuantHandler API definition ###
-### (unify with torchao in future) ###
-
-
-class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
-
-
 #########################################################################
 ### wrapper for setting precision as a QuantHandler ###

@@ -647,6 +623,12 @@ def __init__(
         self.groupsize = groupsize
         self.dtype = dtype
         self.packed = packed
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
         if not packed:
             self.register_buffer(
                 "weight",
@@ -675,12 +657,18 @@ def __init__(
         )

     @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False:  # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.packed:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )

+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
         # result_weights = self.weight.index_select(0, indices.view(-1))
         # result_scales = self.scales.index_select(0, indices.view(-1))

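
For context, the change above picks the embedding's forward implementation once at construction time (via the newly imported use_et_backend()) instead of branching through a dead "if False:" path on every call. Below is a minimal, self-contained sketch of that dispatch pattern, not the PR's code: it assumes 8-bit rows with one scale per embedding row, use_et_backend here is a hard-coded stand-in for the helper imported from build.utils, and the ExecuTorch op call is only a placeholder that would require the quantized_decomposed kernels to be registered.

# Minimal sketch of the construction-time forward dispatch (illustrative only).
import torch
import torch.nn as nn


def use_et_backend() -> bool:
    # Stand-in for build.utils.use_et_backend(): reports whether the model is
    # being lowered for ExecuTorch. Hard-coded so the sketch runs eagerly.
    return False


class QuantizedEmbeddingSketch(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        # int8 rows plus one dequantization scale per row.
        self.register_buffer(
            "weight", torch.zeros(num_embeddings, embedding_dim, dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(num_embeddings, dtype=dtype))
        # Bind the forward implementation once, instead of checking the backend
        # on every call.
        if use_et_backend():
            self.forward = self.et_forward
        else:
            self.forward = self.aoti_forward

    @torch.no_grad()
    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
        # ExecuTorch path: keep the lookup as a quantized op so it can be
        # lowered; needs the ExecuTorch kernels registered to actually run.
        return torch.ops.quantized_decomposed.embedding_byte.dtype(
            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
        )

    @torch.no_grad()
    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
        # Eager/AOTI path: gather int8 rows and their scales, then dequantize.
        rows = self.weight.index_select(0, indices.reshape(-1)).to(dtype=self.dtype)
        scales = self.scales.index_select(0, indices.reshape(-1)).unsqueeze(-1)
        return (rows * scales).reshape(*indices.shape, -1)


# Example: a 16-token vocabulary with 8-dim embeddings, looked up eagerly.
emb = QuantizedEmbeddingSketch(16, 8)
print(emb(torch.tensor([[1, 2], [3, 4]])).shape)  # torch.Size([2, 2, 8])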