
Commit 75f1eea

Support different model types (Llama, Qwen3, Llava) in iOS app (#10615)
1 parent bdca6e3 commit 75f1eea

3 files changed: +43, -11 lines changed

examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,7 @@
 306A71502DC1DC3D00936B1F /* regex.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71492DC1DC3D00936B1F /* regex.cpp */; };
 306A71512DC1DC3D00936B1F /* pre_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */; };
 306A71522DC1DC3D00936B1F /* token_decoder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */; };
+3072D5232DC3EA280083FC83 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3072D5222DC3EA280083FC83 /* Constants.swift */; };
 F292B0752D88B0C200BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06F2D88B0C200BE6839 /* tiktoken.cpp */; };
 F292B0762D88B0C200BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */; };
 F292B0772D88B0C200BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */; };
@@ -147,6 +148,7 @@
 306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = regex.cpp; path = src/regex.cpp; sourceTree = "<group>"; };
 306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = std_regex.cpp; path = src/std_regex.cpp; sourceTree = "<group>"; };
 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = token_decoder.cpp; path = src/token_decoder.cpp; sourceTree = "<group>"; };
+3072D5222DC3EA280083FC83 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
 F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
 F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
 F292B06F2D88B0C200BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -208,6 +210,7 @@
 0324D6892BAACB6900DEF36F /* Application */ = {
 isa = PBXGroup;
 children = (
+3072D5222DC3EA280083FC83 /* Constants.swift */,
 0324D6802BAACB6900DEF36F /* App.swift */,
 0324D6812BAACB6900DEF36F /* ContentView.swift */,
 0324D6822BAACB6900DEF36F /* LogManager.swift */,
@@ -554,6 +557,7 @@
 buildActionMask = 2147483647;
 files = (
 0324D6932BAACB6900DEF36F /* ResourceMonitor.swift in Sources */,
+3072D5232DC3EA280083FC83 /* Constants.swift in Sources */,
 0324D68D2BAACB6900DEF36F /* LogManager.swift in Sources */,
 0324D68E2BAACB6900DEF36F /* LogView.swift in Sources */,
 0324D68F2BAACB6900DEF36F /* Message.swift in Sources */,

examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift

Lines changed: 38 additions & 10 deletions
@@ -80,6 +80,25 @@ struct ContentView: View {
     case tokenizer
   }
 
+  enum ModelType {
+    case llama
+    case llava
+    case qwen3
+
+    static func fromPath(_ path: String) -> ModelType {
+      let filename = (path as NSString).lastPathComponent.lowercased()
+      if filename.hasPrefix("llama") {
+        return .llama
+      } else if filename.hasPrefix("llava") {
+        return .llava
+      } else if filename.hasPrefix("qwen3") {
+        return .qwen3
+      }
+      print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, or qwen3")
+      exit(1)
+    }
+  }
+
   private var placeholder: String {
     resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
   }
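The new ModelType helper above dispatches purely on the lowercased filename prefix of the selected .pte file. For reference, a minimal usage sketch; the file paths below are hypothetical and not taken from this commit:

// Illustrative only: hypothetical model file paths.
let textModel   = ModelType.fromPath("/models/llama3_2_1b.pte")  // .llama
let qwenModel   = ModelType.fromPath("/models/qwen3_0_6b.pte")   // .qwen3
let visionModel = ModelType.fromPath("/models/llava_model.pte")  // .llava
// Any other filename prefix prints "Unknown model type in path: ..." and calls exit(1).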
@@ -275,14 +294,14 @@ struct ContentView: View {
     let seq_len = 768 // text: 256, vision: 768
     let modelPath = resourceManager.modelPath
     let tokenizerPath = resourceManager.tokenizerPath
-    let useLlama = modelPath.lowercased().contains("llama")
+    let modelType = ModelType.fromPath(modelPath)
 
     prompt = ""
     hideKeyboard()
     showingSettings = false
 
     messages.append(Message(text: text))
-    messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))
+    messages.append(Message(type: modelType == .llama ? .llamagenerated : .llavagenerated))
 
     runnerQueue.async {
       defer {
@@ -292,14 +311,16 @@ struct ContentView: View {
         }
       }
 
-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
         runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
-      } else {
+      case .llava:
         runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
       }
 
       guard !shouldStopGenerating else { return }
-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
         if let runner = runnerHolder.runner, !runner.isLoaded() {
           var error: Error?
           let startLoadTime = Date()
@@ -329,7 +350,7 @@ struct ContentView: View {
             return
           }
         }
-      } else {
+      case .llava:
         if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
           var error: Error?
           let startLoadTime = Date()
@@ -411,12 +432,19 @@ struct ContentView: View {
            }
          }
        } else {
-         let llama3_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+         let prompt: String
+         switch modelType {
+         case .qwen3:
+           prompt = String(format: Constants.qwen3PromptTemplate, text)
+         case .llama:
+           prompt = String(format: Constants.llama3PromptTemplate, text)
+         case .llava:
+           prompt = String(format: Constants.llama3PromptTemplate, text)
+         }
 
-         try runnerHolder.runner?.generate(llama3_prompt, sequenceLength: seq_len) { token in
+         try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in
 
-           NSLog(">>> token={\(token)}")
-           if token != llama3_prompt {
+           if token != prompt {
             // hack to fix the issue that extension/llm/runner/text_token_generator.h
             // keeps generating after <|eot_id|>
             if token == "<|eot_id|>" {
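The prompt construction above relies on a new Constants.swift, which this commit registers in the Xcode project but whose contents do not appear in this diff view. A minimal sketch of what such a file could contain, assuming String(format:)-style templates with a single %@ placeholder for the user text: the Llama 3 template mirrors the llama3_prompt string removed above, while the Qwen3 template follows the standard ChatML layout and is an assumption rather than text taken from the commit.

// Constants.swift — hypothetical sketch; the committed file may differ.
import Foundation

enum Constants {
  // Mirrors the llama3_prompt string previously inlined in ContentView.swift.
  static let llama3PromptTemplate =
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

  // Assumed ChatML-style template for Qwen3; verify against the actual Constants.swift.
  static let qwen3PromptTemplate =
    "<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"
}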

examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ import UIKit
 
 enum MessageType {
   case prompted
-  case llamagenerated
+  case llamagenerated // TODO: change this to something more general, like "textgenerated".
   case llavagenerated
   case info
 }
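The TODO added above hints at a model-agnostic case name. A minimal sketch of that possible follow-up, not part of this commit:

// Hypothetical rename suggested by the TODO: llamagenerated becomes a generic
// case used for any text-only model (Llama, Qwen3, ...).
enum MessageType {
  case prompted
  case textgenerated
  case llavagenerated
  case info
}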
