Skip to content

Commit 4a7dab8

Browse files
[AOTI] Remove the original model weights in Python deployment (#1337)
* [AOTI] Remove the original model weights in Python deployment Summary: Fixes #1302. Because an AOTI-compiled model contains a copy of the model weights, we need to release the corresponding eager model weights in the Python deployment path. * Revert "[AOTI] Remove the original model weights in Python deployment" This reverts commit 962ec0d. * Refactor the code * Add setup_cache for aoti_package_path --------- Co-authored-by: Jack-Khuu <[email protected]>
1 parent 54455a3 commit 4a7dab8

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

torchchat/cli/builder.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
536536
model = _load_model_default(builder_args)
537537
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
538538

539+
if builder_args.dso_path or builder_args.aoti_package_path:
540+
# AOTI-compiled model will load its own weights.
541+
# Release weights here to avoid OOM
542+
import gc
543+
if hasattr(model, "model"):
544+
model.model = None
545+
gc.collect()
546+
torch.cuda.empty_cache()
547+
539548
model = model.to(device=builder_args.device, dtype=builder_args.precision)
540549
return model.eval()
541550

@@ -584,6 +593,12 @@ def _initialize_model(
584593
# attributes will NOT be seen by the AOTI-compiled forward
585594
# function, e.g. calling model.setup_cache will NOT touch
586595
# AOTI compiled and maintained model buffers such as kv_cache.
596+
# Using the cpp runner to run an AOTI-compiled model is recommended.
597+
598+
def do_nothing(max_batch_size, max_seq_length):
599+
pass
600+
model.setup_caches = do_nothing
601+
587602
model.forward = torch._export.aot_load(
588603
str(builder_args.dso_path.absolute()), builder_args.device
589604
)
@@ -617,6 +632,11 @@ def _initialize_model(
617632
aoti_compiled_model = load_package(
618633
str(builder_args.aoti_package_path.absolute())
619634
)
635+
636+
def do_nothing(max_batch_size, max_seq_length):
637+
pass
638+
model.setup_caches = do_nothing
639+
620640
model.forward = aoti_compiled_model
621641
metadata = aoti_compiled_model.get_metadata()
622642
builder_args.device = metadata["AOTI_DEVICE_KEY"]

0 commit comments

Comments
 (0)