Commit 053b6ce

Update
[ghstack-poisoned]
2 parents: 0dce3a8 + b5c0c61

33 files changed: 448 additions, 352 deletions
.ci/scripts/unittest-linux.sh

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ else
 fi
 
 # The generic Linux job chooses to use base env, not the one setup by the image
+eval "$(conda shell.bash hook)"
 CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
 conda activate "${CONDA_ENV}"

.ci/scripts/unittest-macos.sh

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ else
 fi
 
 bash .ci/scripts/setup-conda.sh
+eval "$(conda shell.bash hook)"
 
 # Create temp directory for sccache shims
 export TMP_DIR=$(mktemp -d)
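
Note: "conda shell.bash hook" prints the shell functions that define conda activate, so eval-ing it lets the activation step below work in these non-interactive CI shells, where conda's usual ~/.bashrc initialization is never sourced. Both CI scripts get the same one-line fix.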

backends/arm/arm_vela.py

Lines changed: 8 additions & 1 deletion
@@ -39,7 +39,12 @@ def vela_bin_pack_io(prefix, data, shape_order=None):
 # Output via Vela to binary stream for ArmBackendEthosU
 # WARNING: Do not change this without changing VelaBinStream.cpp as that
 # function consumes this format and the two need to align.
-def vela_compile(tosa_flatbuffer: bytes, args: List[str], shape_order=None):
+def vela_compile(
+    tosa_flatbuffer: bytes, args: List[str], shape_order=None, verbose: bool = False
+):
+    """
+    Compile a TOSA graph to a binary stream for ArmBackendEthosU using Vela.
+    """
     with tempfile.TemporaryDirectory() as tmpdir:
         tosaname = "out.tosa"
         tosa_path = os.path.join(tmpdir, tosaname)
@@ -50,6 +55,8 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], shape_order=None):
         output_dir = os.path.join(tmpdir, "output")
         args.append(f"--output-dir={output_dir}")
         args.append(tosa_path)
+        if verbose:
+            args.append("--verbose-all")
         vela.main(" ".join(args).split(" "))
 
         if any("ethos-u85" in arg for arg in args) or any(

backends/arm/ethosu_backend.py

Lines changed: 6 additions & 1 deletion
@@ -58,7 +58,12 @@ def _compile_tosa_flatbuffer(
         )
 
         # Pass on the TOSA flatbuffer to the vela compiler.
-        binary = vela_compile(tosa_flatbuffer, compile_flags, input_order)
+        binary = vela_compile(
+            tosa_flatbuffer,
+            compile_flags,
+            input_order,
+            verbose=logger.getEffectiveLevel() == logging.INFO,
+        )
         return binary
 
     @staticmethod
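
Note: a minimal sketch of how a caller would turn on the verbose Vela output wired up above, assuming the backend keeps a module-level logger named after its module path (the exact logger name here is an assumption):

import logging

# Raising this logger to INFO makes _compile_tosa_flatbuffer pass verbose=True
# to vela_compile, which in turn appends --verbose-all to the Vela invocation.
logging.getLogger("executorch.backends.arm.ethosu_backend").setLevel(logging.INFO)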

backends/arm/operators/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ python_library(
         "//executorch/backends/arm:tosa_mapping",
         "//executorch/backends/arm:tosa_quant_utils",
         "//executorch/backends/arm:tosa_utils",
+        "//executorch/backends/arm/_passes:passes",
         "//executorch/exir:lib",
     ],
 )

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 20 additions & 22 deletions
@@ -594,46 +594,46 @@ bool validate_flash_attention_args(
     const Tensor& key,
     const Tensor& value,
     const optional<Tensor>& attn_mask) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");
 
   // Sizes
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      (query.scalar_type() == ScalarType::Float), "Query must be Float type");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      (query.scalar_type() == key.scalar_type()) &&
          (query.scalar_type() == value.scalar_type()),
      "Key and Value must have the same data type as Query");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      !attn_mask.has_value() || attn_mask.value().dim() == 2,
      "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      !attn_mask.has_value() ||
          attn_mask.value().scalar_type() == query.scalar_type(),
      "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(query.dim_order().data(), query.dim()),
      "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(key.dim_order().data(), key.dim()),
      "value cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
      "value cache must be in contiguous dim order");
 
   if (attn_mask.has_value()) {
-    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+    ET_CHECK_OR_RETURN_FALSE(
        is_contiguous_dim_order(
            attn_mask.value().dim_order().data(), attn_mask.value().dim()),
        "value cache must be in contiguous dim order");
@@ -647,21 +647,19 @@ bool validate_cache_params(
     const Tensor& v_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      k_cache.dim() == 4, "kcache must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(k_cache.dim() == 4, "kcache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      v_cache.dim() == 4, "v_cache must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(v_cache.dim() == 4, "v_cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      start_pos < k_cache.size(1),
      "start_pos must be less than key cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      start_pos < v_cache.size(1),
      "start_pos must be less than value cache at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      (start_pos + seq_length) <= k_cache.size(1),
      "start_post + seq_length must be less than max seq length supported by key cache."
      "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -671,7 +669,7 @@ bool validate_cache_params(
      seq_length,
      k_cache.size(1));
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      (start_pos + seq_length) <= v_cache.size(1),
      "start_post + seq_length must be less than max seq length supported by key cache."
      "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -682,11 +680,11 @@ bool validate_cache_params(
      v_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
      "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
      "value cache must be in contiguous dim order");

extension/llm/custom_ops/op_update_cache.cpp

Lines changed: 6 additions & 6 deletions
@@ -25,17 +25,17 @@ bool validate_cache_params(
     const Tensor& quantized_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      start_pos < quantized_cache.size(1),
      "start_pos must be less than cache size at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      (start_pos + seq_length) <= quantized_cache.size(1),
      "start_post + seq_length must be less than max seq length supported by cache."
      "start pos: %" PRId64 ", seq_length: %" PRId64
@@ -46,12 +46,12 @@ bool validate_cache_params(
      quantized_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(
          quantized_cache.dim_order().data(), quantized_cache.dim()),
      "quantized cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
      is_contiguous_dim_order(
          quantized_value.dim_order().data(), quantized_value.dim()),
      "quantized value must be in contiguous dim order");

extension/llm/tokenizer/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ def define_common_targets():
         name = "tokenizer_py_lib",
         srcs = [
             "__init__.py",
+            "hf_tokenizer.py",
             "tokenizer.py",
             "utils.py",
         ],
