Commit 518d961

Update on "[mps] Disable dialect verifier under mps preprocess"
As title. Reverting dim_order ops produces an illegal IR, which is OK given we are inside the MPS preprocess function, which shouldn't have to adhere to edge_ir constraints. Differential Revision: [D73205726](https://our.internmc.facebook.com/intern/diff/D73205726/) [ghstack-poisoned]
2 parents 6ba9773 + 52ba322 commit 518d961
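Note: the MPS preprocess change named in the title is not among the hunks shown below. For context, here is a minimal sketch of how EXIR dialect verification is usually relaxed from user code, via the public EdgeCompileConfig knob. The verifier toggle inside the MPS preprocess function itself may work differently, so treat this as an illustrative assumption, not the commit's mechanism:

    # Sketch only: relaxing EXIR dialect verification so a graph with
    # intentionally non-conforming ops (e.g. reverted dim_order ops) can
    # pass through to_edge without tripping the edge-dialect verifier.
    # The commit disables verification inside MPS preprocess itself; this
    # public knob is an illustrative stand-in.
    import torch
    from executorch.exir import EdgeCompileConfig, to_edge

    class Model(torch.nn.Module):
        def forward(self, x):
            return x + 1

    ep = torch.export.export(Model(), (torch.randn(2, 2),))
    edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))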


51 files changed: +930 −228 lines

.github/workflows/android-release-artifacts.yml

Lines changed: 7 additions & 1 deletion
@@ -80,6 +80,12 @@ jobs:
 
       echo -n "$SECRET_EXECUTORCH_MAVEN_SIGNING_GPG_KEY_CONTENTS" | base64 -d > /tmp/secring.gpg
 
+      # Update the version name in build.gradle in case of maven publish
+      VERSION="${{ inputs.version }}"
+      if [ ! -z "$VERSION" ]; then
+        sed -i "s/\(coordinates(\"org.pytorch\", \"executorch-android\", \"\)\([0-9]\+.[0-9]\+.[0-9]\+\)\(\")\)/\1$VERSION\3/" extension/android/executorch_android/build.gradle
+      fi
+
       # Build AAR Package
       mkdir aar-out
       export BUILD_AAR_DIR=aar-out

@@ -92,7 +98,7 @@ jobs:
       # Publish to maven staging
       UPLOAD_TO_MAVEN="${{ inputs.upload_to_maven }}"
       if [[ "$UPLOAD_TO_MAVEN" == "true" ]]; then
-        (cd aar-out; ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:publishToMavenCentral)
+        (cd extension/android; ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:publishToMavenCentral)
       fi
 
   upload-release-aar:
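The sed invocation in the first hunk above is dense. A hedged Python equivalent (illustration only, not part of the workflow) makes the three capture groups explicit: the prefix up to the version, the dotted version number, and the closing quote-paren. The sample line and version value are made up:

    # Mirrors the sed substitution: keep groups 1 and 3, swap group 2
    # (the old version) for the release version.
    import re

    line = 'coordinates("org.pytorch", "executorch-android", "0.5.0")'
    version = "0.6.0"  # stand-in for ${{ inputs.version }}
    pattern = r'(coordinates\("org\.pytorch", "executorch-android", ")([0-9]+\.[0-9]+\.[0-9]+)("\))'
    print(re.sub(pattern, r"\g<1>" + version + r"\g<3>", line))
    # -> coordinates("org.pytorch", "executorch-android", "0.6.0")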

CMakeLists.txt

Lines changed: 6 additions & 2 deletions
@@ -761,12 +761,16 @@ if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor)
 endif()
 
+if(EXECUTORCH_BUILD_EXTENSION_MODULE)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_LLM)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizers)
 endif()
 
-if(EXECUTORCH_BUILD_EXTENSION_MODULE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
+if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
 endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ executorch
 │   └── <a href="devtools/visualization">visualization</a> - Visualization tools for representing model structure and performance metrics.
 ├── <a href="docs">docs</a> - Static docs tooling and documentation source files.
 ├── <a href="examples">examples</a> - Examples of various user flows, such as model export, delegates, and runtime execution.
-├── <a href="exir">exir</a> - Ahead-of-time library: model capture and lowering APIs. EXport Intermediate Representation (EXIR) is a format for representing the result of <a href="https://pytorch.org/docs/main/export.ir_spec.html">torch.export</a>. This directory contains utilities and passes for lowering the EXIR graphs into different <a href="/docs/source/ir-exir.md">dialects</a> and eventually suitable to run on target hardware.
+├── <a href="exir">exir</a> - Ahead-of-time library: model capture and lowering APIs. EXport Intermediate Representation (EXIR) is a format for representing the result of <a href="https://pytorch.org/docs/stable/export.html">torch.export</a>. This directory contains utilities and passes for lowering the EXIR graphs into different <a href="/docs/source/ir-exir.md">dialects</a> and eventually suitable to run on target hardware.
 │   ├── <a href="exir/_serialize">_serialize</a> - Serialize final export artifact.
 │   ├── <a href="exir/backend">backend</a> - Backend delegate ahead of time APIs.
 │   ├── <a href="exir/capture">capture</a> - Program capture.

backends/cadence/aot/replace_ops.py

Lines changed: 0 additions & 25 deletions
@@ -1806,30 +1806,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(op, tuple(new_args), kwargs, meta)
 
 
-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass(ExportPass):
-    """
-    Replace the aten.linalg_vector_norm op with a custom op.
-    aten.linalg_vector_norm is not supported by Jarvis, so we
-    need to replace it with native_batch_norm at all optimization levels.
-    """
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.linalg_vector_norm.default:
-            return super().call_operator(op, args, kwargs, meta)
-
-        assert (
-            len(args) == 1
-        ), "aten.linalg_vector_norm should have 1 argument (a tensor), we do not support any custom variants"
-
-        return super().call_operator(
-            exir_ops.edge.cadence.linalg_vector_norm.default,
-            args,
-            kwargs,
-            meta,
-        )
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
     """

@@ -2243,7 +2219,6 @@ class CadenceReplaceOpsInGraph:
         ReplacePT2DequantWithCadenceDequantPass,
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
-        ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
         # ReplaceGeluWithApproximateGeluPass,
     ]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 0 additions & 31 deletions
@@ -23,7 +23,6 @@
     MakeSliceAndCatDimOutermostPass,
     ReplaceAddMMWithLinearPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
-    ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithIm2RowAndLinear,

@@ -1189,36 +1188,6 @@ def forward(self, x):
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
         )
 
-    def test_replace_aten_linalg_vector_norm_with_cadence_linalg_vector_norm(self):
-        class LinalgVectorNorm(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return torch.linalg.vector_norm(x)
-
-        x = torch.randn(32)
-
-        graph_module = (
-            export_to_edge(LinalgVectorNorm(), (x,)).exported_program().graph_module
-        )
-
-        p = ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
-        # Assert that aten.linalg_vector_norm op was replaced by a
-        # cadence.linalg_vector_norm op
-        self.assertEqual(
-            count_node(
-                graph_after_passes,
-                exir_ops.edge.aten.linalg_vector_norm.default,
-            ),
-            0,
-        )
-        self.assertEqual(
-            count_node(
-                graph_after_passes, exir_ops.edge.cadence.linalg_vector_norm.default
-            ),
-            1,
-        )
-
     def test_replace_aten_where_with_cadence_where_Scalar(self):
         class WhereScalarModel(torch.nn.Module):
             def forward(self, cond: torch.Tensor):

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

Lines changed: 5 additions & 1 deletion
@@ -12,6 +12,7 @@
 #import <executorch/examples/models/llama/runner/runner.h>
 #import <executorch/examples/models/llava/runner/llava_runner.h>
 
+using executorch::extension::llm::GenerationConfig;
 using executorch::extension::llm::Image;
 using executorch::runtime::Error;
 

@@ -61,8 +62,11 @@ - (BOOL)generate:(NSString*)prompt
        sequenceLength:(NSInteger)seq_len
     withTokenCallback:(nullable void (^)(NSString*))callback
                 error:(NSError**)error {
+  const GenerationConfig config{
+      .seq_len = static_cast<int32_t>(seq_len)
+  };
   const auto status = _runner->generate(
-      prompt.UTF8String, seq_len, [callback](const std::string& token) {
+      prompt.UTF8String, config, [callback](const std::string& token) {
         callback(@(token.c_str()));
       });
   if (status != Error::Ok) {

examples/mediatek/executor_runner/mtk_llama_runner.cpp

Lines changed: 2 additions & 4 deletions
@@ -80,11 +80,9 @@ bool MTKLlamaRunner::is_loaded() const {
 
 Error MTKLlamaRunner::generate(
     const std::string& prompt,
-    int32_t seq_len,
+    executorch::extension::llm::GenerationConfig config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }

examples/mediatek/executor_runner/mtk_llama_runner.h

Lines changed: 2 additions & 4 deletions
@@ -43,11 +43,9 @@ class MTKLlamaRunner : public executorch::extension::llm::IRunner {
   Error load();
   Error generate(
       const std::string& prompt,
-      int32_t seq_len = 128,
+      executorch::extension::llm::GenerationConfig config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
+      std::function<void(const Stats&)> stats_callback = {});
   void stop();
 
   LlamaModelOptions get_model_options();

examples/models/llama/main.cpp

Lines changed: 9 additions & 4 deletions
@@ -53,7 +53,7 @@ int32_t main(int32_t argc, char** argv) {
 
   const char* prompt = FLAGS_prompt.c_str();
 
-  double temperature = FLAGS_temperature;
+  float temperature = FLAGS_temperature;
 
   int32_t seq_len = FLAGS_seq_len;
 

@@ -73,13 +73,18 @@
   }
 #endif
   // create llama runner
-  example::Runner runner(model_path, tokenizer_path, temperature);
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+  example::Runner runner(model_path, tokenizer_path);
 
   if (warmup) {
-    runner.warmup(prompt, seq_len);
+    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
-  runner.generate(prompt, seq_len);
+  executorch::extension::llm::GenerationConfig config{
+      .seq_len = seq_len, .temperature = temperature};
+  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+  runner.generate(prompt, config);
 
   return 0;
 }

examples/models/llama/model.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from torchao.utils import TorchAOBaseTensor
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint

@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
                 strict=False,
                 assign=True,
             ) # self.model_ = Transformer(gptconf)
+            for param in self.model_.parameters():
+                if isinstance(param, TorchAOBaseTensor):
+                    param.requires_grad = False
         else:
             print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")