Skip to content

Commit b8eafc2

Browse files
committed
Merge remote-tracking branch 'origin/main' into jni-layer-cpp
2 parents 2ce4b27 + 06b946e commit b8eafc2

File tree

19 files changed

+676
-642
lines changed

19 files changed

+676
-642
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# Android
45+
*.aar

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ python_library(
211211
typing = True,
212212
deps = [
213213
":pass_utils",
214+
":utils",
214215
"//executorch/backends/cadence/aot:pass_utils",
215216
"//executorch/exir:pass_base",
216217
"//executorch/exir/dialects:lib",

backends/cadence/aot/simplify_ops.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19+
from executorch.backends.cadence.aot.utils import rebind
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2122
from executorch.exir.pass_base import ExportPass, ProxyValue
22-
from torch.fx.operator_schemas import get_signature_for_torch_op
2323

2424

2525
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass):
117117
def call_operator(self, op, args, kwargs, meta):
118118
if not isinstance(op, EdgeOpOverload):
119119
return super().call_operator(op, args, kwargs, meta)
120-
assert callable(op)
121120

122-
torch_op_schemas = get_signature_for_torch_op(op._op)
123-
if len(torch_op_schemas) == 0:
124-
return super().call_operator(op, args, kwargs, meta)
125-
126-
matched_schemas = []
127-
# Iterate through all of the schema until we find one that matches
128-
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129-
# values. If none matches, `new_args_and_kwargs` will be None
130-
for candidate_signature in torch_op_schemas:
131-
try:
132-
candidate_signature.bind(*args, **kwargs)
133-
matched_schemas.append(candidate_signature)
134-
except TypeError:
135-
continue
136-
137-
if len(matched_schemas) != 1:
138-
# Did not match any schema. Cannot normalize
139-
return super().call_operator(op, args, kwargs, meta)
140-
141-
sig = matched_schemas[0]
142-
bound_args = sig.bind(*args, **kwargs)
143-
bound_args.apply_defaults()
121+
if (updated_args := rebind(op, args, kwargs)) is not None:
122+
args, kwargs = updated_args
144123

145-
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
124+
return super().call_operator(op, args, kwargs, meta)
146125

147126

148127
# This class encapsulates all the functions that simplify the op's args

backends/cadence/aot/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from executorch.exir import ExecutorchProgramManager, memory
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
21+
from executorch.exir.pass_base import Argument
2122
from tabulate import tabulate
23+
from torch.fx.operator_schemas import get_signature_for_torch_op
2224

2325
from torch.utils._pytree import tree_flatten
2426

@@ -308,3 +310,30 @@ def get_size(self, exir_id: int) -> int:
308310
# Return default memory config for the backend
309311
def get_default_memory_config() -> MemoryConfig:
310312
return MemoryConfig(memory_sizes=[0x1000000000])
313+
314+
315+
def rebind(
    op: EdgeOpOverload, args: tuple[Argument, ...], kwargs: dict[str, Argument]
) -> Optional[tuple[tuple[Argument, ...], dict[str, Argument]]]:
    """Populates optional args and binds args/kwargs based on schema.

    Args:
        op: Edge op whose underlying torch op (`op._op`) supplies the candidate
            signatures.
        args: Positional arguments of the call being normalized.
        kwargs: Keyword arguments of the call being normalized.

    Returns:
        `(args, kwargs)` taken from the single matching signature after its
        defaults have been applied, or None when zero or more than one
        candidate signature accepts the given arguments.
    """
    # get_signature_for_torch_op is documented as possibly returning None when
    # it cannot resolve any schema; treat that the same as "no candidates".
    torch_op_schemas = get_signature_for_torch_op(op._op) or []

    # Try every candidate signature and keep the bound arguments of each one
    # that accepts this call, so the winner does not have to be bound again.
    matched_bindings = []
    for candidate_signature in torch_op_schemas:
        try:
            matched_bindings.append(candidate_signature.bind(*args, **kwargs))
        except TypeError:
            # This signature does not accept the given args/kwargs.
            continue

    if len(matched_bindings) != 1:
        # Either no schema matched or the match is ambiguous. Cannot normalize.
        return None

    bound_args = matched_bindings[0]
    bound_args.apply_defaults()

    return bound_args.args, bound_args.kwargs

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class ModelArgs:
1414
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
1515
ffn_dim_multiplier: Optional[float] = None
1616
norm_eps: float = 1e-5
17-
max_batch_size: int = 32
17+
max_batch_size: int = 1
1818
max_seq_len: int = 2048
1919
max_context_len: int = 2048
2020
moe: bool = False # True to enable the MoE (Mixture of Experts)

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java

Lines changed: 0 additions & 126 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.Manifest
11+
import androidx.test.InstrumentationRegistry
12+
import androidx.test.ext.junit.runners.AndroidJUnit4
13+
import androidx.test.rule.GrantPermissionRule
14+
import org.apache.commons.io.FileUtils
15+
import org.json.JSONException
16+
import org.json.JSONObject
17+
import org.junit.Assert
18+
import org.junit.Before
19+
import org.junit.Rule
20+
import org.junit.Test
21+
import org.junit.runner.RunWith
22+
import org.pytorch.executorch.extension.llm.LlmCallback
23+
import org.pytorch.executorch.extension.llm.LlmModule
24+
import java.io.File
25+
import java.io.IOException
26+
import java.net.URISyntaxException
27+
28+
/**
 * Instrumentation tests for [org.pytorch.executorch.extension.llm.LlmModule].
 *
 * The test class implements [LlmCallback] itself: every generated result is appended to
 * [results] and every parsable stats report contributes one entry to [tokensPerSecond].
 */
@RunWith(AndroidJUnit4::class)
class LlmModuleInstrumentationTest : LlmCallback {
    private val results: MutableList<String> = ArrayList()
    private val tokensPerSecond: MutableList<Float> = ArrayList()

    // Assigned in setUp(), which is annotated @Before, so tests can use it without null checks.
    private lateinit var llmModule: LlmModule

    @get:Rule
    val runtimePermissionRule: GrantPermissionRule =
        GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

    /** Copies the bundled model and tokenizer to the device and creates the module under test. */
    @Before
    @Throws(IOException::class)
    fun setUp() {
        val modelPath = copyResourceToDevice(TEST_FILE_NAME)
        val tokenizerPath = copyResourceToDevice(TOKENIZER_FILE_NAME)

        llmModule = LlmModule(modelPath, tokenizerPath, 0.0f)
    }

    @Test
    @Throws(IOException::class, URISyntaxException::class)
    fun testGenerate() {
        // Check that the model can be loaded successfully
        val loadResult = llmModule.load()
        Assert.assertEquals(OK.toLong(), loadResult.toLong())

        llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
        Assert.assertEquals(SEQ_LEN.toLong(), results.size.toLong())
        // Fail with a readable message instead of an index error when no stats were recorded.
        Assert.assertTrue("No tokens-per-second stats were recorded", tokensPerSecond.isNotEmpty())
        Assert.assertTrue(tokensPerSecond.last() > 0)
    }

    @Test
    @Throws(IOException::class, URISyntaxException::class)
    fun testGenerateAndStop() {
        // Forward callbacks to this test's collectors, but request a stop on every result.
        val stoppingCallback = object : LlmCallback {
            override fun onResult(result: String) {
                this@LlmModuleInstrumentationTest.onResult(result)
                llmModule.stop()
            }

            override fun onStats(stats: String) {
                this@LlmModuleInstrumentationTest.onStats(stats)
            }
        }
        llmModule.generate(TEST_PROMPT, SEQ_LEN, stoppingCallback)

        val stoppedResultSize = results.size
        Assert.assertTrue(stoppedResultSize < SEQ_LEN)
    }

    /** Records one generated result. */
    override fun onResult(result: String) {
        results.add(result)
    }

    /**
     * Parses the JSON [stats] report and records the generation speed in tokens per second,
     * computed from `generated_tokens`, `inference_end_ms` and `prompt_eval_end_ms`.
     */
    override fun onStats(stats: String) {
        try {
            val jsonObject = JSONObject(stats)
            val numGeneratedTokens = jsonObject.getInt("generated_tokens")
            val inferenceEndMs = jsonObject.getInt("inference_end_ms")
            val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
            val tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
            tokensPerSecond.add(tps)
        } catch (_: JSONException) {
            // Best effort: a malformed stats report is ignored and simply records no entry.
        }
    }

    /**
     * Copies the classpath resource [resourceName] to its on-device test path.
     *
     * @return the absolute path of the copied file.
     */
    @Throws(IOException::class)
    private fun copyResourceToDevice(resourceName: String): String {
        val destinationPath = getTestFilePath(resourceName)
        val resourceStream = checkNotNull(javaClass.getResourceAsStream(resourceName)) {
            "Test resource $resourceName was not found"
        }
        // use {} closes the stream even when the copy throws.
        resourceStream.use { stream ->
            FileUtils.copyInputStreamToFile(stream, File(destinationPath))
        }
        return destinationPath
    }

    companion object {
        private const val TEST_FILE_NAME = "/stories.pte"
        private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
        private const val TEST_PROMPT = "Hello"
        private const val OK = 0x00
        private const val SEQ_LEN = 32

        /** Builds the path of [fileName] (which starts with `/`) inside the external cache dir. */
        private fun getTestFilePath(fileName: String): String {
            val cacheDir =
                checkNotNull(InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir) {
                    "External cache directory is not available"
                }
            return "$cacheDir$fileName"
        }
    }
}

0 commit comments

Comments
 (0)