File tree Expand file tree Collapse file tree 2 files changed +0
-10
lines changed Expand file tree Collapse file tree 2 files changed +0
-10
lines changed Original file line number Diff line number Diff line change 18
18
from executorch .examples .models .llama .llama_transformer import Transformer
19
19
20
20
from executorch .examples .models .llama .model_args import ModelArgs
21
- from torchao .utils import TorchAOBaseTensor
22
21
23
22
try :
24
23
from .fairseq2 import convert_to_llama_checkpoint
@@ -258,9 +257,6 @@ def __init__(self, **kwargs):
258
257
strict = False ,
259
258
assign = True ,
260
259
) # self.model_ = Transformer(gptconf)
261
- for param in self .model_ .parameters ():
262
- if isinstance (param , TorchAOBaseTensor ):
263
- param .requires_grad = False
264
260
else :
265
261
print ("Checkpoint not provided, defaulting weights to zeros." )
266
262
self .model_ .to_empty (device = "cpu" )
Original file line number Diff line number Diff line change 41
41
from torch .ao .quantization .quantizer .composable_quantizer import ComposableQuantizer
42
42
from torch .export import export_for_training , ExportedProgram
43
43
from torch .nn .attention import SDPBackend
44
- from torchao .utils import unwrap_tensor_subclass
45
44
46
45
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
47
46
logging .basicConfig (level = logging .INFO , format = FORMAT )
@@ -200,11 +199,6 @@ def _get_edge_config(self) -> EdgeCompileConfig:
200
199
return edge_config
201
200
202
201
def _export (self , module : Optional [torch .nn .Module ] = None ) -> ExportedProgram :
203
- if module is not None :
204
- unwrap_tensor_subclass (module )
205
- else :
206
- unwrap_tensor_subclass (self .model )
207
-
208
202
dynamic_shape = self ._get_dynamic_shape ()
209
203
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
210
204
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
You can’t perform that action at this time.
0 commit comments