
Commit db950ab

Merge remote-tracking branch 'origin/main' into remove-old-api
2 parents 42fe202 + 922a508

File tree: 7 files changed, +72 -146 lines

examples/models/llama/export_llama_lib.py

Lines changed: 39 additions & 63 deletions
@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
+        help="Provide the dtype of the model. This must match up with the supported dtypes of the backends that you are using."
+        "Please be aware that only some backends support fp16 and bf16.",
     )

     parser.add_argument(
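The rewritten help text reflects that the flag now always applies (it defaults to fp32) rather than falling back to the checkpoint dtype. As a rough sketch of the flow this flag feeds into — the DType enum below is a simplified, hypothetical stand-in for the repo's actual DType class, shown only to illustrate the string-to-torch-dtype mapping:

import torch
from enum import Enum

class DType(Enum):
    # Hypothetical stand-in for executorch's DType; real members may differ.
    fp32 = torch.float32
    fp16 = torch.float16
    bf16 = torch.bfloat16

    def to_torch_dtype(self) -> torch.dtype:
        return self.value

# argparse guarantees args.dtype_override is one of "fp32", "fp16", "bf16".
dtype_override = DType["bf16"]
print(dtype_override.to_torch_dtype())  # torch.bfloat16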
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )

-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
     )

+    return edge_manager
+

 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
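The net effect of this hunk is that the dtype decision now happens exactly once: the model is materialized in PyTorch's default fp32 and cast a single time after _load_llama_model returns, instead of each code path picking its own dtype. A minimal sketch of that pattern, with a plain nn.Linear standing in for the Llama Transformer:

import torch

model = torch.nn.Linear(4, 4)                # stand-in; loaded in default fp32
assert model.weight.dtype == torch.float32   # PyTorch's default dtype

# Single post-load cast, mirroring edge_manager.model.to(dtype=...).
model = model.to(dtype=torch.bfloat16)
assert model.weight.dtype == torch.bfloat16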
@@ -1006,6 +1005,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")

+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1023,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")

     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,

examples/models/llama/model.py

Lines changed: 7 additions & 11 deletions
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
             """
         )

-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)

     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_

     def get_example_inputs(self):
         if self.use_kv_cache:
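The key claim in the new comment — that loading a fp16 or bf16 checkpoint into the fp32-initialized Transformer loses nothing — holds because every fp16 and bf16 value is exactly representable in fp32. A quick self-contained check of that property:

import torch

x = torch.randn(8, dtype=torch.bfloat16)
assert torch.equal(x.to(torch.float32).to(torch.bfloat16), x)  # exact round-trip

y = torch.randn(8, dtype=torch.float16)
assert torch.equal(y.to(torch.float32).to(torch.float16), y)   # exact round-trip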

examples/qualcomm/utils.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def execute(self, custom_runner_cmd=None, method_index=0):
         qnn_executor_runner_cmds = " ".join(
             [
                 f"cd {self.workspace} &&",
-                f"chmod +x ./qnn_executor_runner &&",
+                "chmod +x ./qnn_executor_runner &&",
                 f"./qnn_executor_runner {qnn_executor_runner_args}",
             ]
         )
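The only change here drops an f-prefix from a string with no placeholders; pyflakes/flake8 flags that pattern as F541. Both spellings evaluate to the identical literal:

# F541: an f-string with nothing to interpolate is just a plain string.
old = f"chmod +x ./qnn_executor_runner &&"  # what the code had
new = "chmod +x ./qnn_executor_runner &&"   # what it has now
assert old == new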

extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.h

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
                   modelMethodName:(NSString *)modelMethodName
                             error:(NSError * _Nullable * _Nullable)error NS_DESIGNATED_INITIALIZER;

-- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)input
+- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)values
                                                 error:(NSError * _Nullable * _Nullable)error NS_SWIFT_NAME(infer(input:));

 @end

extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.mm

Lines changed: 22 additions & 67 deletions
@@ -13,32 +13,6 @@

 #import <executorch/extension/module/module.h>

-static int kInitFailed = 0;
-static int kInferenceFailed = 1;
-
-static auto NSStringToString(NSString *string) -> std::string
-{
-  const char *cStr = [string cStringUsingEncoding:NSUTF8StringEncoding];
-  if (cStr) {
-    return cStr;
-  }
-
-  NSData *data = [string dataUsingEncoding:NSUTF8StringEncoding allowLossyConversion:NO];
-  return {reinterpret_cast<const char *>([data bytes]), [data length]};
-}
-
-static auto StringToNSString(const std::string &string) -> NSString *
-{
-  CFStringRef cfString = CFStringCreateWithBytes(
-    kCFAllocatorDefault,
-    reinterpret_cast<const UInt8 *>(string.c_str()),
-    string.size(),
-    kCFStringEncodingUTF8,
-    false
-  );
-  return (__bridge_transfer NSString *)cfString;
-}
-
 @implementation ExecutorchRuntimeEngine
 {
   NSString *_modelPath;
@@ -48,66 +22,47 @@ @implementation ExecutorchRuntimeEngine

 - (instancetype)initWithModelPath:(NSString *)modelPath
                   modelMethodName:(NSString *)modelMethodName
-                            error:(NSError * _Nullable * _Nullable)error
+                            error:(NSError **)error
 {
   if (self = [super init]) {
     _modelPath = modelPath;
     _modelMethodName = modelMethodName;
-    try {
-      _module = std::make_unique<torch::executor::Module>(NSStringToString(modelPath));
-      const auto e = _module->load_method(NSStringToString(modelMethodName));
-      if (e != executorch::runtime::Error::Ok) {
-        if (error) {
-          *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
-                                       code:kInitFailed
-                                   userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(static_cast<uint32_t>(e)))}];
-        }
-        return nil;
-      }
-    } catch (...) {
+    _module = std::make_unique<torch::executor::Module>(modelPath.UTF8String);
+    const auto e = _module->load_method(modelMethodName.UTF8String);
+    if (e != executorch::runtime::Error::Ok) {
       if (error) {
         *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
-                                     code:kInitFailed
-                                 userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}];
+                                      code:(NSInteger)e
+                                  userInfo:nil];
       }
       return nil;
     }
   }
   return self;
 }

-- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)input
-                                                error:(NSError * _Nullable * _Nullable)error
+- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)values
+                                                error:(NSError **)error
 {
-  try {
-    std::vector<torch::executor::EValue> inputEValues;
-    inputEValues.reserve(input.count);
-    for (ExecutorchRuntimeValue *inputValue in input) {
-      inputEValues.push_back([inputValue getBackedValue]);
-    }
-    const auto result = _module->execute(NSStringToString(_modelMethodName), inputEValues);
-    if (!result.ok()) {
-      const auto executorchError = static_cast<uint32_t>(result.error());
-      if (error) {
-        *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
-                                     code:kInferenceFailed
-                                 userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(executorchError))}];
-      }
-      return nil;
-    }
-    NSMutableArray<ExecutorchRuntimeValue *> *const resultValues = [NSMutableArray new];
-    for (const auto &evalue : result.get()) {
-      [resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]];
-    }
-    return resultValues;
-  } catch (...) {
+  std::vector<torch::executor::EValue> inputEValues;
+  inputEValues.reserve(values.count);
+  for (ExecutorchRuntimeValue *inputValue in values) {
+    inputEValues.push_back([inputValue getBackedValue]);
+  }
+  const auto result = _module->execute(_modelMethodName.UTF8String, inputEValues);
+  if (!result.ok()) {
    if (error) {
      *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
-                                   code:kInferenceFailed
-                               userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}];
+                                    code:(NSInteger)result.error()
+                                userInfo:nil];
    }
    return nil;
  }
+  NSMutableArray<ExecutorchRuntimeValue *> *const resultValues = [NSMutableArray new];
+  for (const auto &evalue : result.get()) {
+    [resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]];
+  }
+  return resultValues;
 }

 @end

extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/__tests__/ExecutorchRuntimeEngineTests.mm

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ - (void)testInvalidModel
   XCTAssertNil(engine);
   XCTAssertNotNil(runtimeInitError);

-  XCTAssertEqual(runtimeInitError.code, 0);
-  XCTAssertEqualObjects(runtimeInitError.userInfo[NSDebugDescriptionErrorKey], @"34");
+  XCTAssertEqual(runtimeInitError.code, 34);
   // 34 is the code for AccessFailed.
 }

extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ phases:

           adb -s $DEVICEFARM_DEVICE_UDID shell dumpsys deviceidle force-idle
           adb -s $DEVICEFARM_DEVICE_UDID shell dumpsys deviceidle unforce
-          adb -s $DEVICEFARM_DEVICE_UDID shell sleep 30
+          adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180

           if [ -n "$BIN_FOUND" ]; then
             adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \