Commit 962ec0d

[AOTI] Remove the original model weights in Python deployment
Summary: Fixes #1302. Because the AOTI-compiled model contains its own copy of the model weights, the Python deployment path needs to release the corresponding eager model weights.
1 parent f20f5e7 commit 962ec0d

1 file changed, +13 -0 lines changed

torchchat/cli/builder.py

Lines changed: 13 additions & 0 deletions
@@ -544,6 +544,19 @@ def _initialize_model(
         # attributes will NOT be seen on by AOTI-compiled forward
         # function, e.g. calling model.setup_cache will NOT touch
         # AOTI compiled and maintained model buffers such as kv_cache.
+        # Using cpp runner to run AOTI compiled model is recommended.
+        #
+        # Released the loaded model to free up device memory.
+        # The AOTI-compiled model contains a copy of the model weights.
+        model.model = None
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        def do_nothing(max_batch_size, max_seq_length):
+            pass
+        model.setup_caches = do_nothing
+
         model.forward = torch._export.aot_load(
             str(builder_args.dso_path.absolute()), builder_args.device
         )
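
The release pattern in this hunk can be tried in isolation. The following is a minimal sketch, not the torchchat code: `EagerWeights`, `Wrapper`, and `release_eager_weights` are hypothetical stand-ins, and the `torch.cuda.empty_cache()` call is guarded so the snippet also runs where PyTorch or CUDA is unavailable. It assumes CPython, where dropping the last reference frees the object, and uses a weak reference to check that the eager weights are gone.

```python
import gc
import weakref


class EagerWeights:
    """Hypothetical stand-in for the eager model that owns the weights."""

    def __init__(self, n_bytes):
        self.data = bytearray(n_bytes)


class Wrapper:
    """Hypothetical stand-in for the wrapper object whose .model is released."""

    def __init__(self):
        self.model = EagerWeights(1024)

    def setup_caches(self, max_batch_size, max_seq_length):
        # Depends on self.model, so it would fail once the weights are dropped.
        return len(self.model.data)


def release_eager_weights(wrapper):
    # Drop the only strong reference to the eager weights.
    wrapper.model = None
    # Collect any reference cycles that might still keep them alive.
    gc.collect()
    # Return cached, now-unused CUDA blocks to the driver when possible.
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass  # Assumption: running without PyTorch is acceptable for this sketch.

    # Replace setup_caches with a no-op so later callers do not touch
    # the released model. It is stored on the instance, so it takes no self.
    def do_nothing(max_batch_size, max_seq_length):
        pass

    wrapper.setup_caches = do_nothing


wrapper = Wrapper()
weights_ref = weakref.ref(wrapper.model)
release_eager_weights(wrapper)

weights_freed = weights_ref() is None
cache_result = wrapper.setup_caches(1, 128)
print(weights_freed, cache_result)
```

Because `do_nothing` is assigned on the instance rather than the class, it is a plain function and is called with exactly the two arguments that callers of `setup_caches` pass.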
