
Commit d0bc16e

Update
[ghstack-poisoned]
2 parents accd815 + 80d5e5a commit d0bc16e

File tree (7 files changed: +89 -24 lines changed)

  backends/xnnpack/partition/config/gemm_configs.py
  backends/xnnpack/test/ops/test_linear.py
  backends/xnnpack/test/ops/test_lstm.py
  examples/qualcomm/oss_scripts/llama/TARGETS
  runtime/core/portable_type/c10/c10/macros/Export.h
  runtime/core/portable_type/c10/c10/util/BFloat16.h
  runtime/core/portable_type/c10/c10/util/Half.h

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 14 additions & 1 deletion
@@ -210,6 +210,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # If force_fp32_dynamic_linear is enabled, then we
+            # do not partition the weight node.
+            return (True, gemm_deps)
+
         if len(node.all_input_nodes) > 2 and self.bias_idx is not None:
             bias_node = get_input_node(node, self.bias_idx)
             if bias_node:
@@ -477,7 +482,15 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users

-        return valid_deps, list(set(deps) | set(src_partition.nodes))
+        # When using force_fp32_dynamic_linear, we want get_deps to overwrite
+        # the source partition nodes. Else we want to be greedy.
+        ret_deps = (
+            list(set(deps) & set(src_partition.nodes))
+            if self.force_fp32_dynamic_linear
+            else list(set(deps) | set(src_partition.nodes))
+        )
+
+        return valid_deps, ret_deps

     def supported_precision_types(self):
         return [

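Note on the dependency change above: with force_fp32_dynamic_linear set, get_deps keeps only the overlap between the collected deps and the source partition nodes instead of greedily taking their union, so the linear's weight and bias constants stay outside the delegate. A minimal sketch of the two set operations, using hypothetical node names that are not part of this diff:

# Hypothetical node names; only the set semantics mirror the diff above.
deps = {"linear_node"}                                    # deps collected for the gemm node
src_partition_nodes = {"linear_node", "weight_node", "bias_node"}

greedy = deps | src_partition_nodes       # default: delegate also claims weight/bias
restricted = deps & src_partition_nodes   # force_fp32_dynamic_linear: weight/bias stay out

print(sorted(greedy))       # ['bias_node', 'linear_node', 'weight_node']
print(sorted(restricted))   # ['linear_node']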
backends/xnnpack/test/ops/test_linear.py

Lines changed: 70 additions & 0 deletions
@@ -31,6 +31,8 @@
     ToEdgeTransformAndLower,
 )

+from torch.export.graph_signature import ExportGraphSignature, InputKind
+
 try:
     from torchao.quantization.quant_api import (
         int8_dynamic_activation_int4_weight,
@@ -871,3 +873,71 @@ def test_linear_qd8_as_fp32(self):
                 "dequantize_per_channel.default": 1,  # 1: weight
             },
         )
+
+    def test_linear_fp32_with_force_as_mm(self):
+        def check_signature(
+            signature: ExportGraphSignature,
+            force_flag: bool,
+            use_bias: bool,
+            legacy_mode: bool,
+        ):
+            num_params = 0
+            if force_flag:
+                num_params = 1  # weight_param
+                if use_bias:
+                    num_params += 1  # bias_param
+            sign_params: int = 0
+            input_specs = signature.input_specs
+            for input_spec in input_specs:
+                if input_spec.kind == InputKind.PARAMETER:
+                    sign_params += 1
+            assert (
+                sign_params == num_params
+            ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"
+
+        for force_flag in (True, False):
+            for use_bias in (True, False):
+                for legacy_mode in (True, False):
+                    module = BaseLinear(
+                        in_size=8,
+                        input_channels=13,
+                        output_channels=17,
+                        use_bias=use_bias,
+                    )
+                    inputs = module.get_inputs()
+                    tester = Tester(module, inputs).export()
+                    partitioner = XnnpackPartitioner(
+                        force_fp32_dynamic_linear=force_flag
+                    )
+                    if legacy_mode:
+                        tester.to_edge()
+                        partitioner_stage = Partition(partitioner=partitioner)
+                        tester.partition(partition_stage=partitioner_stage)
+                        tester.check_not(
+                            [
+                                (
+                                    "executorch_exir_dialects_edge__ops_aten_mm_default"
+                                    if use_bias
+                                    else "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                                )
+                            ]
+                        )
+                    else:
+                        to_edge_and_transform_stage = ToEdgeTransformAndLower(
+                            partitioners=[partitioner]
+                        )
+                        tester.to_edge_transform_and_lower(
+                            to_edge_and_transform_stage=to_edge_and_transform_stage
+                        )
+                        tester.check_not(
+                            ["executorch_exir_dialects_edge__ops_aten_linear_default"]
+                        )
+
+                    signature: ExportGraphSignature = (
+                        tester.get_artifact().exported_program().graph_signature
+                    )
+                    check_signature(signature, force_flag, use_bias, legacy_mode)
+
+                    tester.to_executorch()
+                    tester.serialize()
+                    tester.run_method_and_compare_outputs()

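One reading of the new check_signature helper, consistent with the test_lstm.py change below: the exported program retains the linear's weight (and bias, if present) as parameters only when force_fp32_dynamic_linear is enabled; otherwise the XNNPACK delegate consumes them. A small sketch of those expected counts; this is an interpretation, not code from the diff:

# Expected number of InputKind.PARAMETER entries, keyed by
# (force_fp32_dynamic_linear, use_bias). Interpretation of check_signature above.
expected_param_count = {
    (True, True): 2,    # weight and bias remain program parameters
    (True, False): 1,   # only the weight remains
    (False, True): 0,   # delegate consumes weight and bias
    (False, False): 0,  # delegate consumes the weight
}

for (force_flag, use_bias), count in expected_param_count.items():
    print(f"force={force_flag}, bias={use_bias} -> {count} parameter(s)")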
backends/xnnpack/test/ops/test_lstm.py

Lines changed: 2 additions & 3 deletions
@@ -54,9 +54,8 @@ def test_fp32_lstm_force_dynamic_linear(self):
             )
             .check_not(["executorch_exir_dialects_edge__ops_aten_addmm_default"])
             # Weights are supplied as input to linears
-            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0"])
-            # Biases are owned by delegates
-            .check_not(["p_lstm_bias"])
+            # Biases are not owned by delegates when force_fp32_dynamic_linear is set
+            .check(["p_lstm_weight_hh_l0", "p_lstm_weight_ih_l0", "p_lstm_bias"])
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()

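Outside the test harness, the flag exercised by these tests is passed straight to XnnpackPartitioner. A hedged usage sketch of a standalone lowering flow follows; the partitioner import path and keyword mirror the tests above, while the rest of the flow (to_edge_transform_and_lower, to_executorch) is standard ExecuTorch API and not part of this diff:

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

# Plain FP32 linear; with the flag set, its weight/bias are kept as program
# parameters instead of being folded into the XNNPACK delegate payload.
model = torch.nn.Linear(13, 17).eval()
example_inputs = (torch.randn(8, 13),)

exported = torch.export.export(model, example_inputs)
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
program = edge.to_executorch()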
examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 1 addition & 12 deletions
@@ -35,23 +35,12 @@ python_library(

 python_binary(
     name = "llama",
-    srcs = ["llama.py"],
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
     preload_deps = [
         "//executorch/extension/llm/custom_ops:model_sharding_py",
     ],
     deps = [
-        "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
-        "//caffe2:torch",
-        "//executorch/extension/pybindings:aten_lib",
-        "//executorch/backends/qualcomm/partition:partition",
-        "//executorch/backends/qualcomm/quantizer:quantizer",
-        "//executorch/devtools/backend_debug:delegation_info",
-        "//executorch/devtools:lib",
-        "//executorch/examples/models:models",
-        "//executorch/examples/qualcomm:utils",
-        "//executorch/extension/export_util:export_util",
-        "//executorch/extension/llm/export:export_lib",
+        ":llama_lib",
     ],
 )

runtime/core/portable_type/c10/c10/macros/Export.h

Lines changed: 2 additions & 0 deletions
@@ -139,8 +139,10 @@
 #endif

 #if defined(TORCH_HIP_BUILD_MAIN_LIB)
+#define TORCH_HIP_CPP_API C10_EXPORT
 #define TORCH_HIP_API C10_EXPORT
 #else
+#define TORCH_HIP_CPP_API C10_IMPORT
 #define TORCH_HIP_API C10_IMPORT
 #endif

runtime/core/portable_type/c10/c10/util/BFloat16.h

Lines changed: 0 additions & 4 deletions
@@ -8,9 +8,7 @@
 #include <cstdint>
 #include <cstring>
 #include <iosfwd>
-#ifndef C10_EMBEDDED
 #include <ostream>
-#endif // C10_EMBEDDED

 #if defined(__CUDACC__) && !defined(USE_ROCM)
 #include <cuda_bf16.h>
@@ -116,14 +114,12 @@ struct alignas(2) BFloat16 {
 #endif
 };

-#ifndef C10_EMBEDDED
 C10_API inline std::ostream& operator<<(
     std::ostream& out,
     const BFloat16& value) {
   out << (float)value;
   return out;
 }
-#endif // C10_EMBEDDED

 } // namespace c10

runtime/core/portable_type/c10/c10/util/Half.h

Lines changed: 0 additions & 4 deletions
@@ -29,9 +29,7 @@
 #include <cstring>
 #include <iosfwd>
 #include <limits>
-#ifndef C10_EMBEDDED
 #include <ostream>
-#endif // C10_EMBEDDED

 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -411,12 +409,10 @@ struct alignas(2) Half {
 #endif
 };

-#ifndef C10_EMBEDDED
 C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
   out << (float)value;
   return out;
 }
-#endif // C10_EMBEDDED

 } // namespace c10
