Skip to content

Commit 5aed7ae

Browse files
[AOTI] Add a --max-seq-length option for export (#1018)
Summary: This improves best tokens/sec from 73 to 85. Co-authored-by: Jack-Khuu <[email protected]>
1 parent ce41944 commit 5aed7ae

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

build/builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class BuilderArgs:
4545
is_chat_model: bool = False
4646
prefill_possible: bool = False
4747
dynamic_shapes: bool = False
48+
max_seq_length: Optional[int] = None
4849

4950
def __post_init__(self):
5051
if self.device is None:
@@ -159,6 +160,7 @@ def from_args(cls, args): # -> BuilderArgs:
159160
use_distributed=args.distributed,
160161
is_chat_model=is_chat_model,
161162
dynamic_shapes=getattr(args, "dynamic_shapes", False),
163+
max_seq_length=getattr(args, "max_seq_length", None),
162164
)
163165

164166
@classmethod
@@ -437,6 +439,7 @@ def _initialize_model(
437439
builder_args,
438440
quantize,
439441
tokenizer=None,
442+
max_seq_length=None,
440443
):
441444
print("Loading model...")
442445

@@ -513,7 +516,7 @@ def _initialize_model(
513516
if builder_args.setup_caches:
514517
with torch.device(builder_args.device):
515518
model.setup_caches(
516-
max_batch_size=1, max_seq_length=model.config.max_seq_length
519+
max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length
517520
)
518521

519522
model.to(dtype=builder_args.precision)

cli.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def add_arguments_for_verb(parser, verb: str) -> None:
6868
_add_generation_args(parser, verb)
6969
if verb == "export":
7070
_add_export_output_path_args(parser)
71+
_add_export_args(parser)
7172
if verb == "eval":
7273
_add_exported_input_path_args(parser)
7374
_add_evaluation_args(parser)
@@ -185,11 +186,20 @@ def _add_export_output_path_args(parser) -> None:
185186
default=None,
186187
help="Output to the specified AOT Inductor .dso model file",
187188
)
189+
190+
191+
def _add_export_args(parser) -> None:
188192
parser.add_argument(
189193
"--dynamic-shapes",
190194
action="store_true",
191195
help="Call torch.export with dynamic shapes",
192196
)
197+
parser.add_argument(
198+
"--max-seq-length",
199+
type=int,
200+
default=None,
201+
help="Set maximum length sequence when before calling torch.export",
202+
)
193203

194204

195205
# Add CLI Args representing user provided exported model files

export.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,19 @@ def main(args):
113113
except:
114114
tokenizer = None
115115

116+
if (
117+
output_dso_path is not None
118+
and builder_args.max_seq_length is None
119+
and not builder_args.dynamic_shapes
120+
):
121+
print("Setting max_seq_length to 300 for DSO export.")
122+
builder_args.max_seq_length = 300
123+
116124
model = _initialize_model(
117125
builder_args,
118126
quantize,
119127
tokenizer,
128+
max_seq_length=builder_args.max_seq_length,
120129
)
121130
model_to_pte = model
122131
model_to_dso = model

0 commit comments

Comments
 (0)