Commit b634b31

Update on "[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue."
This diff restricts the current conv2d dw implementation in the Vulkan backend of ExecuTorch to the special case where stride equals dilation, since that is the only case in which this kind of caching is possible. When stride does not equal dilation, the old implementation is used. Additional test cases are added to ensure computation is correct when stride != dilation.

Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)

[ghstack-poisoned]
2 parents 5048523 + f027deb commit b634b31
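
As a sketch of the rule the commit message describes (hypothetical names; this is not the actual Vulkan backend dispatch code), the specialized depthwise path is selected only when stride equals dilation:

# Hypothetical illustration only; the kernel names are placeholders,
# not the real ExecuTorch Vulkan shader names.
def pick_conv2d_dw_impl(stride: tuple, dilation: tuple) -> str:
    if stride == dilation:
        # Caching rows across output positions only lines up in this case,
        # so the specialized depthwise kernel is safe to use.
        return "conv2d_dw_cached"    # placeholder name
    # Any other stride/dilation combination falls back to the old,
    # general depthwise implementation.
    return "conv2d_dw_general"       # placeholder name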

File tree

20 files changed: +620 -60 lines

backends/cadence/fusion_g3/operators/op_mean.cpp

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ int prepare_data(
   return num_axis_dims;
 }
 
-Tensor& mean_dim_out(
+Tensor& mean_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
@@ -199,4 +199,4 @@ Tensor& mean_dim_out(
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@
 ${define_active_storage_type(STORAGE)}
 
 ${define_required_extensions(DTYPE)}
-${define_required_extensions("int8")}
+$if STORAGE == "buffer":
+  ${define_required_extensions("int8")}
 
 #include "indexing_utils.h"
 

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@
 ${define_active_storage_type(STORAGE)}
 
 ${define_required_extensions(DTYPE)}
-${define_required_extensions("int8")}
+$if STORAGE == "buffer":
+  ${define_required_extensions("int8")}
 
 
 $if BATCH_MODE:
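
Both shader templates above gate the int8 extension requirement on the storage type. A minimal sketch of what the `$if` guard means at codegen time, using a hypothetical helper (the real expansion is done by ExecuTorch's shader template generator, not this function):

# Hypothetical illustration of the template guard; emit_int8_extensions is
# a made-up name, not the real codegen API.
def emit_int8_extensions(storage: str) -> list:
    lines = []
    if storage == "buffer":  # mirrors: $if STORAGE == "buffer":
        lines.append('${define_required_extensions("int8")}')
    return lines  # texture variants get no int8 extension directives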

examples/models/llama/export_llama_lib.py

Lines changed: 8 additions & 2 deletions
@@ -23,6 +23,7 @@
 import torch
 
 from executorch.devtools.etrecord import generate_etrecord
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
@@ -775,6 +776,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -789,7 +793,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(
+            passes=additional_passes,
+        )
 
         # Generate ETRecord
         if edge_manager_copy:
@@ -807,7 +813,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
            canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(passes=additional_passes)
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
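
Why torchtune-defined models get this pass (an illustration with made-up values, consistent with the emitter change below): a position buffer registered as torch.arange(...) would otherwise be serialized as shape and dtype only, so the runtime would see zeros rather than the intended positions.

import torch

# Intended initial state of a KV-cache position buffer such as cache_pos.
cache_pos = torch.arange(10)                 # tensor([0, 1, ..., 9])

# Without InitializedMutableBufferPass, only shape and dtype are serialized,
# so the runtime starts from a default zero state instead.
runtime_view = torch.zeros_like(cache_pos)   # tensor([0, 0, ..., 0])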

examples/models/llama3_2_vision/runner/native.py

Lines changed: 15 additions & 3 deletions
@@ -18,13 +18,15 @@
     TorchTuneLlamaRunner,
 )
 
-from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
 
@@ -43,7 +45,17 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        # Save the loaded model bytes to prevent data from going out of
+        # scope after the `with` and getting cleaned up by Python's
+        # garbage collector.
+        self.model_bytes = None
+        with open(args.pte, "rb") as f:
+            self.model_bytes = f.read()
+            # Need to use _load_for_executorch_from_buffer instead of
+            # _load_for_executorch because the latter uses MmapDataLoader,
+            # which doesn't have load_into() implemented, which is needed
+            # for loading initialized mutable buffers.
+            self.model = _load_for_executorch_from_buffer(self.model_bytes)
         self.use_kv_cache = args.kv_cache
 
     def forward(
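
Pulled out of the change above, the load-from-buffer pattern on its own (a sketch; the Loader class is made up for illustration): the bytes are kept on the object so they outlive the `with` block, and the buffer-backed loader supports the load_into() needed for initialized mutable buffers.

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

class Loader:  # hypothetical wrapper, for illustration only
    def __init__(self, pte_path: str):
        with open(pte_path, "rb") as f:
            # Held on self so the buffer is not garbage collected after load.
            self.model_bytes = f.read()
        self.model = _load_for_executorch_from_buffer(self.model_bytes)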

exir/emit/_emitter.py

Lines changed: 7 additions & 2 deletions
@@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
         warnings.warn(
             "Mutation on a buffer in the model is detected. ExecuTorch assumes "
             "buffers that are mutated in the graph have a meaningless initial state, "
-            "only the shape and dtype will be serialized.",
+            "only the shape and dtype will be serialized, unless a pass which sets "
+            'meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.',
             UserWarning,
             stacklevel=1,
         )
@@ -1602,6 +1603,7 @@ def placeholder(
         """
         spec = self.node.meta["spec"]
         constant_tag = self.node.meta.get("constant_tag", None)
+        initialize_buffer = self.node.meta.get("et_init_buffer", None)
         is_user_input = True
 
         if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1655,7 +1657,10 @@
             spec.storage = real_tensor.untyped_storage()
 
         # User inputs and mutable buffers are not constants, other buffers or parameters are.
-        spec.const = not (is_user_input or is_mutable_buffer)
+        if initialize_buffer and is_mutable_buffer:
+            spec.const = True
+        else:
+            spec.const = not (is_user_input or is_mutable_buffer)
 
         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
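
The resulting constness rule, restated as a small truth function for clarity (a paraphrase of the emitter logic above, not code from the commit):

def spec_is_const(is_user_input: bool, is_mutable_buffer: bool,
                  initialize_buffer: bool) -> bool:
    # An explicitly initialized mutable buffer is serialized as a constant,
    # preserving its initial values in the .pte file.
    if initialize_buffer and is_mutable_buffer:
        return True
    # Otherwise user inputs and mutable buffers are non-constant;
    # parameters and other buffers are constant.
    return not (is_user_input or is_mutable_buffer)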

exir/emit/test/test_emit.py

Lines changed: 53 additions & 0 deletions
@@ -9,6 +9,7 @@
 import typing
 import unittest
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import List, Optional, Tuple
 
 import executorch.exir as exir
@@ -31,6 +32,7 @@
 from executorch.exir.error import InternalError
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
@@ -56,6 +58,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
+from executorch.runtime import Runtime
 
 from functorch.experimental import control_flow
 from torch import nn
@@ -243,6 +246,56 @@ def forward(self, x):
         )
         self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
 
+    def test_initialized_mutable_buffer(self):
+        """Test that mutable buffers can hold meaningful initialized state."""
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Mutable buffer with non-empty initial state.
+                self.register_buffer("cache_pos", torch.arange(0, 10))
+
+            def forward(self, x):
+                self.cache_pos.add_(1)
+                return self.cache_pos
+
+        m = TestModule()
+        example_inputs = (torch.ones(10),)
+        ep = torch.export.export(m, example_inputs)
+        edge = to_edge(
+            ep,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+            ),
+        )
+
+        # Save a copy of the edge program since to_executorch is
+        # stateful to some degree.
+        edge_copy = deepcopy(edge)
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program_init_pass = edge.to_executorch(config=et_config)
+        et_program_regular = edge_copy.to_executorch()
+
+        runtime = Runtime.get()
+        program_init_pass = runtime.load_program(et_program_init_pass.buffer)
+        method_init_pass = program_init_pass.load_method("forward")
+
+        program_regular = runtime.load_program(et_program_regular.buffer)
+        method_regular = program_regular.load_method("forward")
+
+        # Test that the mutable buffer is initialized.
+        torch.allclose(
+            method_init_pass.execute((example_inputs))[0], torch.arange(1, 11)
+        )
+        # Test that the mutable buffer is uninitialized and starts with default zeros;
+        # we compare against torch.ones because of the in-place add of 1 in forward.
+        torch.allclose(
+            method_regular.execute((example_inputs))[0],
+            torch.ones(10, dtype=torch.int64),
+        )
+
     def test_int_list_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):

exir/passes/init_mutable_pass.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+from executorch.exir.pass_base import ExportPass
+
+
+class InitializedMutableBufferPass(ExportPass):
+    """
+    If a buffer has a name that is within a specified list, set meta["et_init_buffer"]
+    to True, which provides the mutable buffer with an initialized state.
+
+    As an example, a module with `self.register_buffer("cache_pos", torch.arange(10))`
+    when patterns = ["cache_pos"] would have its initial state set instead of being
+    left uninitialized by default.
+    """
+
+    def __init__(self, patterns: List[str]) -> None:
+        super().__init__()
+        self.patterns = patterns
+
+    def placeholder(self, name: str, arg, meta):
+        for pattern in self.patterns:
+            if pattern in name:
+                meta["et_init_buffer"] = True
+
+        return super().placeholder(name, arg, meta)
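
Typical usage, mirroring the new test above (a sketch; `edge` is assumed to come from an earlier to_edge(...) call):

from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# Any placeholder whose name contains "cache_pos" keeps its initial state.
et_config = ExecutorchBackendConfig(
    passes=[InitializedMutableBufferPass(["cache_pos"])],
)
et_program = edge.to_executorch(config=et_config)  # `edge` from to_edge(...)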

extension/llm/export/builder.py

Lines changed: 15 additions & 8 deletions
@@ -27,6 +27,7 @@
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 
+from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -415,21 +416,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
-    def to_executorch(self) -> "LLMEdgeManager":
+    def to_executorch(
+        self, passes: Optional[List[PassType]] = None
+    ) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
         """
         assert self.edge_manager, "Need to run export_to_edge() first"
+        to_executorch_passes = [
+            # If there are Linear operations left in the graph, let's execute
+            # them with the optimized op_linear rather than materializing a
+            # transpose followed by a regular op_mm.
+            ConvertToLinearPass(),
+            QuantFusionPass(),
+        ]
+        if passes:
+            to_executorch_passes.extend(passes)
+
         self.export_program = self.edge_manager.to_executorch(
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
-                passes=[
-                    # If there are Linear operations left in the graph, let's execute
-                    # them with the optimized op_linear rather than materializing a
-                    # transpose followed by a regular op_mm.
-                    ConvertToLinearPass(),
-                    QuantFusionPass(),
-                ],
+                passes=to_executorch_passes,
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
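
With this change, callers can thread extra passes through the LLM export flow (a sketch; `manager` is assumed to be an LLMEdgeManager that has already been exported to edge):

from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# The default ConvertToLinearPass and QuantFusionPass still run first;
# caller-supplied passes are appended after them.
manager = manager.to_executorch(
    passes=[InitializedMutableBufferPass(["cache_pos"])],
)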
