Skip to content

Commit 5e24bdb

Browse files
hsharma35 authored and GregoryComer committed
Create utility to rebind args/kwargs.
Differential Revision: D75029675 Pull Request resolved: #10987
1 parent 2ede762 commit 5e24bdb

File tree

7 files changed

+54
-36
lines changed

7 files changed

+54
-36
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ xcuserdata/
4040
.swiftpm/
4141
*.xcworkspace/
4242
*.xcframework/
43+
44+
# Android
45+
*.aar

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ python_library(
211211
typing = True,
212212
deps = [
213213
":pass_utils",
214+
":utils",
214215
"//executorch/backends/cadence/aot:pass_utils",
215216
"//executorch/exir:pass_base",
216217
"//executorch/exir/dialects:lib",

backends/cadence/aot/simplify_ops.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19+
from executorch.backends.cadence.aot.utils import rebind
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2122
from executorch.exir.pass_base import ExportPass, ProxyValue
22-
from torch.fx.operator_schemas import get_signature_for_torch_op
2323

2424

2525
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass):
117117
def call_operator(self, op, args, kwargs, meta):
118118
if not isinstance(op, EdgeOpOverload):
119119
return super().call_operator(op, args, kwargs, meta)
120-
assert callable(op)
121120

122-
torch_op_schemas = get_signature_for_torch_op(op._op)
123-
if len(torch_op_schemas) == 0:
124-
return super().call_operator(op, args, kwargs, meta)
125-
126-
matched_schemas = []
127-
# Iterate through all of the schema until we find one that matches
128-
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129-
# values. If none matches, `new_args_and_kwargs` will be None
130-
for candidate_signature in torch_op_schemas:
131-
try:
132-
candidate_signature.bind(*args, **kwargs)
133-
matched_schemas.append(candidate_signature)
134-
except TypeError:
135-
continue
136-
137-
if len(matched_schemas) != 1:
138-
# Did not match any schema. Cannot normalize
139-
return super().call_operator(op, args, kwargs, meta)
140-
141-
sig = matched_schemas[0]
142-
bound_args = sig.bind(*args, **kwargs)
143-
bound_args.apply_defaults()
121+
if (updated_args := rebind(op, args, kwargs)) is not None:
122+
args, kwargs = updated_args
144123

145-
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
124+
return super().call_operator(op, args, kwargs, meta)
146125

147126

148127
# This class encapsulates all the functions that simplify the op's args

backends/cadence/aot/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from executorch.exir import ExecutorchProgramManager, memory
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
21+
from executorch.exir.pass_base import Argument
2122
from tabulate import tabulate
23+
from torch.fx.operator_schemas import get_signature_for_torch_op
2224

2325
from torch.utils._pytree import tree_flatten
2426

@@ -308,3 +310,30 @@ def get_size(self, exir_id: int) -> int:
308310
# Return default memory config for the backend
309311
def get_default_memory_config() -> MemoryConfig:
310312
return MemoryConfig(memory_sizes=[0x1000000000])
313+
314+
315+
def rebind(
316+
op: EdgeOpOverload, args: tuple[Argument, ...], kwargs: dict[str, Argument]
317+
) -> Optional[tuple[tuple[Argument, ...], dict[str, Argument]]]:
318+
"""Populates optional args and binds args/kwargs based on schema."""
319+
torch_op_schemas = get_signature_for_torch_op(op._op)
320+
321+
matched_schemas = []
322+
# Iterate through all of the schema until we find one that matches
323+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
324+
# values. If none matches, `new_args_and_kwargs` will be None
325+
for candidate_signature in torch_op_schemas:
326+
try:
327+
candidate_signature.bind(*args, **kwargs)
328+
matched_schemas.append(candidate_signature)
329+
except TypeError:
330+
continue
331+
332+
if len(matched_schemas) != 1:
333+
# Did not match any schema. Cannot normalize
334+
return None
335+
336+
bound_args = matched_schemas[0].bind(*args, **kwargs)
337+
bound_args.apply_defaults()
338+
339+
return bound_args.args, bound_args.kwargs

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,19 @@ public class Module {
4747
*
4848
* @param modelPath path to file that contains the serialized ExecuTorch module.
4949
* @param loadMode load mode for the module. See constants in {@link Module}.
50+
* @param numThreads the number of threads to use for inference. A value of 0 defaults to a
51+
* hardware-specific default.
5052
* @return new {@link org.pytorch.executorch.Module} object which owns the model module.
5153
*/
52-
public static Module load(final String modelPath, int loadMode) {
54+
public static Module load(final String modelPath, int loadMode, int numThreads) {
5355
if (!NativeLoader.isInitialized()) {
5456
NativeLoader.init(new SystemDelegate());
5557
}
5658
File modelFile = new File(modelPath);
5759
if (!modelFile.canRead() || !modelFile.isFile()) {
5860
throw new RuntimeException("Cannot load model path " + modelPath);
5961
}
60-
return new Module(new NativePeer(modelPath, loadMode));
62+
return new Module(new NativePeer(modelPath, loadMode, numThreads));
6163
}
6264

6365
/**

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class NativePeer {
2828
private final HybridData mHybridData;
2929

3030
@DoNotStrip
31-
private static native HybridData initHybrid(String moduleAbsolutePath, int loadMode);
31+
private static native HybridData initHybrid(
32+
String moduleAbsolutePath, int loadMode, int numThreads);
3233

33-
NativePeer(String moduleAbsolutePath, int loadMode) {
34-
mHybridData = initHybrid(moduleAbsolutePath, loadMode);
34+
NativePeer(String moduleAbsolutePath, int loadMode, int numThreads) {
35+
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
3536
}
3637

3738
/** Clean up the native resources associated with this instance */

extension/android/jni/jni_layer.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
228228
static facebook::jni::local_ref<jhybriddata> initHybrid(
229229
facebook::jni::alias_ref<jclass>,
230230
facebook::jni::alias_ref<jstring> modelPath,
231-
jint loadMode) {
232-
return makeCxxInstance(modelPath, loadMode);
231+
jint loadMode,
232+
jint numThreads) {
233+
return makeCxxInstance(modelPath, loadMode, numThreads);
233234
}
234235

235-
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
236+
ExecuTorchJni(
237+
facebook::jni::alias_ref<jstring> modelPath,
238+
jint loadMode,
239+
jint numThreads) {
236240
Module::LoadMode load_mode = Module::LoadMode::Mmap;
237241
if (loadMode == 0) {
238242
load_mode = Module::LoadMode::File;
@@ -259,11 +263,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
259263
// Based on testing, this is almost universally faster than using all
260264
// cores, as efficiency cores can be quite slow. In extreme cases, using
261265
// all cores can be 10x slower than using cores/2.
262-
//
263-
// TODO Allow overriding this default from Java.
264266
auto threadpool = executorch::extension::threadpool::get_threadpool();
265267
if (threadpool) {
266-
int thread_count = cpuinfo_get_processors_count() / 2;
268+
int thread_count =
269+
numThreads != 0 ? numThreads : cpuinfo_get_processors_count() / 2;
267270
if (thread_count > 0) {
268271
threadpool->_unsafe_reset_threadpool(thread_count);
269272
}

0 commit comments

Comments (0)