Skip to content

Commit b8a2cbd

Browse files
authored
Add LLaVa runner.
Differential Revision: D62142005 Pull Request resolved: #5053
1 parent 9ae7c0d commit b8a2cbd

File tree

4 files changed

+152
-2
lines changed

4 files changed

+152
-2
lines changed

examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
03729F132BB2042B00152F2E /* sampler.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 03729F112BB2042B00152F2E /* sampler.cpp */; };
4444
03729F162BB2043600152F2E /* bpe_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 03729F142BB2043600152F2E /* bpe_tokenizer.cpp */; };
4545
03729F172BB2043600152F2E /* tokenizer.h in Headers */ = {isa = PBXBuildFile; fileRef = 03729F152BB2043600152F2E /* tokenizer.h */; };
46+
0372C3112C893FE900CD942A /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 0372C3102C893FE900CD942A /* CoreGraphics.framework */; };
47+
0372C3142C89418E00CD942A /* llava_runner.h in Headers */ = {isa = PBXBuildFile; fileRef = 0372C3122C89418E00CD942A /* llava_runner.h */; };
48+
0372C3152C89418E00CD942A /* llava_runner.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0372C3132C89418E00CD942A /* llava_runner.cpp */; };
4649
038D678C2C482C1E00B88CF2 /* llama_tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 038D678A2C482C1D00B88CF2 /* llama_tiktoken.cpp */; };
4750
038D678D2C482C1E00B88CF2 /* llama_tiktoken.h in Headers */ = {isa = PBXBuildFile; fileRef = 038D678B2C482C1E00B88CF2 /* llama_tiktoken.h */; };
4851
03BADE202BD2E88600DDFDC2 /* bpe_tokenizer.h in Headers */ = {isa = PBXBuildFile; fileRef = 03BADE1F2BD2E88600DDFDC2 /* bpe_tokenizer.h */; };
@@ -141,11 +144,14 @@
141144
03729ED52BB1F8DE00152F2E /* LLaMARunner.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLaMARunner.framework; sourceTree = BUILT_PRODUCTS_DIR; };
142145
03729F072BB203B300152F2E /* runner.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = runner.cpp; path = ../../../examples/models/llama2/runner/runner.cpp; sourceTree = "<group>"; };
143146
03729F082BB203B300152F2E /* runner.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = runner.h; path = ../../../examples/models/llama2/runner/runner.h; sourceTree = "<group>"; };
144-
03729F092BB203B300152F2E /* util.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = util.h; path = ../../../../extension/llm/runner/util.h; sourceTree = "<group>"; };
147+
03729F092BB203B300152F2E /* util.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = util.h; sourceTree = "<group>"; };
145148
03729F102BB2042B00152F2E /* sampler.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = sampler.h; sourceTree = "<group>"; };
146149
03729F112BB2042B00152F2E /* sampler.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = sampler.cpp; sourceTree = "<group>"; };
147150
03729F142BB2043600152F2E /* bpe_tokenizer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer.cpp; path = ../../../../extension/llm/tokenizer/bpe_tokenizer.cpp; sourceTree = "<group>"; };
148151
03729F152BB2043600152F2E /* tokenizer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = tokenizer.h; path = ../../../../extension/llm/tokenizer/tokenizer.h; sourceTree = "<group>"; };
152+
0372C3102C893FE900CD942A /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
153+
0372C3122C89418E00CD942A /* llava_runner.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = llava_runner.h; path = ../../../examples/models/llava/runner/llava_runner.h; sourceTree = "<group>"; };
154+
0372C3132C89418E00CD942A /* llava_runner.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = llava_runner.cpp; path = ../../../examples/models/llava/runner/llava_runner.cpp; sourceTree = "<group>"; };
149155
038D678A2C482C1D00B88CF2 /* llama_tiktoken.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = llama_tiktoken.cpp; sourceTree = "<group>"; };
150156
038D678B2C482C1E00B88CF2 /* llama_tiktoken.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = llama_tiktoken.h; sourceTree = "<group>"; };
151157
03BADE1F2BD2E88600DDFDC2 /* bpe_tokenizer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = bpe_tokenizer.h; path = ../../../../extension/llm/tokenizer/bpe_tokenizer.h; sourceTree = "<group>"; };
@@ -190,6 +196,7 @@
190196
isa = PBXFrameworksBuildPhase;
191197
buildActionMask = 2147483647;
192198
files = (
199+
0372C3112C893FE900CD942A /* CoreGraphics.framework in Frameworks */,
193200
03312C3E2BBFD076002106EF /* executorch_debug in Frameworks */,
194201
);
195202
runOnlyForDeploymentPostprocessing = 0;
@@ -323,6 +330,8 @@
323330
03729F062BB2035900152F2E /* runner */ = {
324331
isa = PBXGroup;
325332
children = (
333+
0372C3132C89418E00CD942A /* llava_runner.cpp */,
334+
0372C3122C89418E00CD942A /* llava_runner.h */,
326335
03729F072BB203B300152F2E /* runner.cpp */,
327336
03729F082BB203B300152F2E /* runner.h */,
328337
03D03DA92C7823830088D6A7 /* text_decoder_runner.cpp */,
@@ -373,6 +382,7 @@
373382
84DD947F2C81060E00C765A6 /* Frameworks */ = {
374383
isa = PBXGroup;
375384
children = (
385+
0372C3102C893FE900CD942A /* CoreGraphics.framework */,
376386
);
377387
name = Frameworks;
378388
sourceTree = "<group>";
@@ -403,6 +413,7 @@
403413
038D678D2C482C1E00B88CF2 /* llama_tiktoken.h in Headers */,
404414
03729F0C2BB203B300152F2E /* util.h in Headers */,
405415
03729F0B2BB203B300152F2E /* runner.h in Headers */,
416+
0372C3142C89418E00CD942A /* llava_runner.h in Headers */,
406417
);
407418
runOnlyForDeploymentPostprocessing = 0;
408419
};
@@ -646,6 +657,7 @@
646657
03729EE12BB1F93800152F2E /* LLaMARunner.mm in Sources */,
647658
03BADE232BD2EB6700DDFDC2 /* tiktoken.cpp in Sources */,
648659
038D678C2C482C1E00B88CF2 /* llama_tiktoken.cpp in Sources */,
660+
0372C3152C89418E00CD942A /* llava_runner.cpp in Sources */,
649661
03D03DAB2C7823830088D6A7 /* text_decoder_runner.cpp in Sources */,
650662
03729F162BB2043600152F2E /* bpe_tokenizer.cpp in Sources */,
651663
03729F0A2BB203B300152F2E /* runner.cpp in Sources */,

examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import LLaMARunner
1313

1414
/// Observable object that holds optional references to the two runner kinds
/// exported by the LLaMARunner framework. Neither property is `@Published`,
/// so assigning to them does not by itself notify observers.
class RunnerHolder: ObservableObject {
  /// Text-only runner (`LLaMARunner`, exposed to Swift as `Runner`); nil until assigned.
  var runner: Runner? = nil
  /// Image + text runner (`LLaVARunner`); nil until assigned.
  var llavaRunner: LLaVARunner? = nil
}
1718

1819
struct ContentView: View {

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#import <Foundation/Foundation.h>
9+
#import <UIKit/UIKit.h>
1010

1111
NS_ASSUME_NONNULL_BEGIN
1212

1313
FOUNDATION_EXPORT NSErrorDomain const LLaMARunnerErrorDomain;
14+
FOUNDATION_EXPORT NSErrorDomain const LLaVARunnerErrorDomain;
1415

1516
NS_SWIFT_NAME(Runner)
1617
@interface LLaMARunner : NSObject
@@ -23,6 +24,30 @@ NS_SWIFT_NAME(Runner)
2324
sequenceLength:(NSInteger)seq_len
2425
withTokenCallback:(nullable void (^)(NSString*))callback
2526
error:(NSError**)error;
27+
- (BOOL)generate:(NSArray<UIImage*>*)images
28+
prompt:(NSString*)prompt
29+
sequenceLength:(NSInteger)seq_len
30+
withTokenCallback:(nullable void (^)(NSString*))callback
31+
error:(NSError**)error;
32+
- (void)stop;
33+
34+
+ (instancetype)new NS_UNAVAILABLE;
35+
- (instancetype)init NS_UNAVAILABLE;
36+
37+
@end
38+
39+
NS_SWIFT_NAME(LLaVARunner)
40+
@interface LLaVARunner : NSObject
41+
42+
- (instancetype)initWithModelPath:(NSString*)filePath
43+
tokenizerPath:(NSString*)tokenizerPath;
44+
- (BOOL)isloaded;
45+
- (BOOL)loadWithError:(NSError**)error;
46+
- (BOOL)generate:(NSArray<UIImage*>*)images
47+
prompt:(NSString*)prompt
48+
sequenceLength:(NSInteger)seq_len
49+
withTokenCallback:(nullable void (^)(NSString*))callback
50+
error:(NSError**)error;
2651
- (void)stop;
2752

2853
+ (instancetype)new NS_UNAVAILABLE;

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
#import <ExecuTorch/ExecuTorchLog.h>
1212
#import <executorch/examples/models/llama2/runner/runner.h>
13+
#import <executorch/examples/models/llava/runner/llava_runner.h>
1314

1415
using namespace ::torch::executor;
1516

1617
NSErrorDomain const LLaMARunnerErrorDomain = @"LLaMARunnerErrorDomain";
18+
NSErrorDomain const LLaVARunnerErrorDomain = @"LLaVARunnerErrorDomain";
1719

1820
@interface LLaMARunner ()<ExecuTorchLogSink>
1921
@end
@@ -102,3 +104,113 @@ - (void)logWithLevel:(ExecuTorchLogLevel)level
102104
}
103105

104106
@end
107+
108+
// Private conformance: the runner registers itself as an ExecuTorch log sink
// so that runtime log messages are forwarded to NSLog (see -logWithLevel:...).
@interface LLaVARunner ()<ExecuTorchLogSink>
@end

// Objective-C wrapper around the C++ LlavaRunner. It converts UIImage inputs
// into raw RGBA byte buffers and maps C++ Error codes to NSError values in
// LLaVARunnerErrorDomain.
@implementation LLaVARunner {
  // Owned C++ runner; created in the initializer and destroyed with self.
  std::unique_ptr<LlavaRunner> _runner;
}

// Creates the underlying LlavaRunner from the given model and tokenizer paths
// and registers self as a log sink. The model is not loaded here; call
// -loadWithError: (or let the C++ runner decide when to load).
- (instancetype)initWithModelPath:(NSString*)modelPath
                    tokenizerPath:(NSString*)tokenizerPath {
  self = [super init];
  if (self) {
    [ExecuTorchLog.sharedLog addSink:self];
    _runner = std::make_unique<LlavaRunner>(
        modelPath.UTF8String, tokenizerPath.UTF8String);
  }
  return self;
}

// Balances the -addSink: performed in the initializer.
- (void)dealloc {
  [ExecuTorchLog.sharedLog removeSink:self];
}

// Returns whatever the C++ runner reports from is_loaded().
- (BOOL)isloaded {
  return _runner->is_loaded();
}

// Calls load() on the C++ runner.
// Returns YES on Error::Ok. On failure returns NO and, when `error` is
// non-NULL, fills it with an NSError in LLaVARunnerErrorDomain whose code is
// the raw C++ Error value.
- (BOOL)loadWithError:(NSError**)error {
  const auto status = _runner->load();
  if (status != Error::Ok) {
    if (error) {
      *error = [NSError errorWithDomain:LLaVARunnerErrorDomain
                                   code:(NSInteger)status
                               userInfo:nil];
    }
    return NO;
  }
  return YES;
}

// Runs generation for `prompt` conditioned on `images`.
//
// Each UIImage is rendered into a tightly packed RGBA buffer (8 bits per
// component, 4 bytes per pixel, premultiplied alpha last) and handed to the
// C++ runner together with the prompt and `seq_len`.
//
// `callback` may be nil; when provided it receives each generated token as an
// NSString.
//
// Returns YES on Error::Ok. On failure returns NO regardless of whether
// `error` was supplied; when `error` is non-NULL it is filled with an NSError
// in LLaVARunnerErrorDomain whose code is the raw C++ Error value.
- (BOOL)generate:(NSArray<UIImage*>*)images
               prompt:(NSString*)prompt
       sequenceLength:(NSInteger)seq_len
    withTokenCallback:(nullable void (^)(NSString*))callback
                error:(NSError**)error {
  std::vector<Image> rawImages;
  rawImages.reserve(images.count);

  // CGColorSpaceCreateDeviceRGB follows the Create rule, so the returned
  // object must be released. Create it once for all images instead of leaking
  // one per image.
  CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
  for (UIImage* image in images) {
    CGImageRef cgImage = image.CGImage;
    const int32_t width = static_cast<int32_t>(CGImageGetWidth(cgImage));
    const int32_t height = static_cast<int32_t>(CGImageGetHeight(cgImage));
    // Compute sizes in size_t so large images cannot overflow 32-bit math.
    const size_t bytesPerRow = static_cast<size_t>(width) * 4;
    std::vector<uint8_t> buffer(static_cast<size_t>(height) * bytesPerRow);
    CGContextRef context = CGBitmapContextCreate(
        buffer.data(),
        width,
        height,
        8,
        bytesPerRow,
        colorSpace,
        kCGImageAlphaPremultipliedLast);
    // Context creation fails (returns NULL) for e.g. a zero-sized image or a
    // UIImage that has no CGImage backing; skip drawing in that case.
    if (context) {
      CGContextDrawImage(context, CGRectMake(0, 0, width, height), cgImage);
      CGContextRelease(context);
    }
    rawImages.push_back({std::move(buffer), width, height, 4});
  }
  CGColorSpaceRelease(colorSpace);

  const auto status = _runner->generate(
      std::move(rawImages),
      prompt.UTF8String,
      seq_len,
      [callback](const std::string& token) {
        // The callback is declared nullable; invoking a nil block crashes.
        if (callback) {
          callback(@(token.c_str()));
        }
      });
  if (status != Error::Ok) {
    if (error) {
      *error = [NSError errorWithDomain:LLaVARunnerErrorDomain
                                   code:(NSInteger)status
                               userInfo:nil];
    }
    // Report failure even when the caller did not ask for error details.
    return NO;
  }
  return YES;
}

// Forwards the stop request to the C++ runner.
- (void)stop {
  _runner->stop();
}

#pragma mark - ExecuTorchLogSink

// Prints one ExecuTorch log record through NSLog as
// "<level> HH:MM:SS.uuuuuu executorch:<file>:<line>] <message>".
// `timestamp` is split into hours (mod 24), minutes, seconds and the
// fractional part expressed in microseconds.
- (void)logWithLevel:(ExecuTorchLogLevel)level
           timestamp:(NSTimeInterval)timestamp
            filename:(NSString*)filename
                line:(NSUInteger)line
             message:(NSString*)message {
  NSUInteger totalSeconds = (NSUInteger)timestamp;
  NSUInteger hours = (totalSeconds / 3600) % 24;
  NSUInteger minutes = (totalSeconds / 60) % 60;
  NSUInteger seconds = totalSeconds % 60;
  NSUInteger microseconds = (timestamp - totalSeconds) * 1000000;
  NSLog(
      @"%c %02lu:%02lu:%02lu.%06lu executorch:%s:%zu] %s",
      (char)level,
      hours,
      minutes,
      seconds,
      microseconds,
      filename.UTF8String,
      line,
      message.UTF8String);
}

@end

0 commit comments

Comments
 (0)