
Commit 1b5184d

Update base for Update on "Remove llama related stuff out of bpe_tokenizer"
We don't need to initialize `vocab_`, `vocab_scores_`, etc. They will be initialized anyway while loading the tokenizer binary. A benefit of removing them is that we can drop these llama-related default values and make `bpe_tokenizer` agnostic to models.

Differential Revision: [D59664556](https://our.internmc.facebook.com/intern/diff/D59664556/)

[ghstack-poisoned]
2 parents 165c38a + f9efb05 commit 1b5184d

File tree: 16 files changed (+203, -20 lines)

backends/transforms/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ runtime.python_library(
         "//executorch/backends/...",
         "//executorch/examples/...",
         "//executorch/extension/llm/...",
+        "@EXECUTORCH_CLIENTS",
     ],
     deps = [
         "//caffe2:torch",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __contains__(self, op):
 BINARY_OPS = [
     exir_ops.edge.aten.add.Tensor,
     exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.minimum.default,
     exir_ops.edge.aten.mul.Tensor,
     exir_ops.edge.aten.div.Tensor,
     exir_ops.edge.aten.div.Tensor_mode,

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 2 additions & 0 deletions
@@ -28,3 +28,5 @@ binary_op:
     OPERATOR: pow(X, Y)
   - NAME: binary_floor_divide
     OPERATOR: floor(X / Y)
+  - NAME: binary_minimum
+    OPERATOR: min(X, Y)

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 2 additions & 0 deletions
@@ -118,6 +118,7 @@ DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);
 DEFINE_BINARY_OP_FN(mul);
 DEFINE_BINARY_OP_FN(div);
 DEFINE_BINARY_OP_FN(pow);
+DEFINE_BINARY_OP_FN(minimum);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
@@ -126,6 +127,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.div.Tensor, div);
   VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
   VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
+  VK_REGISTER_OP(aten.minimum.default, minimum);
 }
 
 } // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
@@ -1022,3 +1022,19 @@ def get_constant_pad_nd_inputs():
         ]
     )
     return test_suite
+
+
+@register_test_suite("aten.minimum.default")
+def get_minimum_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1, M2), (M2)),
+            ((M1, M2), (M1, M2)),
+            ((M1, M2, M), (M2, M)),
+            ((M1, M1, S1, S2), (M1, M1, S1, S2)),
+            ((S1, S1, S2, S), (S1, S2, S)),
+            ((M1, S1, S2), (L, M1, S1, S2)),
+            ((S1, S2), (L, M1, S1, S2)),
+        ]
+    )
+    return test_suite

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 19 additions & 0 deletions
@@ -1072,6 +1072,25 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_minimum(self):
+        class MinimumModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return torch.minimum(x, y)
+
+        sample_inputs = (
+            torch.rand(size=(3, 5, 6, 4), dtype=torch.float32),
+            torch.rand(size=(6, 4), dtype=torch.float32),
+        )
+
+        self.lower_module_and_test_output(
+            MinimumModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 3 additions & 1 deletion
@@ -869,7 +869,9 @@ def __init__(
         self.quant = quant
 
         # TODO(T174256335) - remove this once we have a better way to handle >2d Mask
-        self._lower_recomposed_sdpa: bool = _lower_recomposed_sdpa or True
+        self._lower_recomposed_sdpa: bool = (
+            _lower_recomposed_sdpa if _lower_recomposed_sdpa is not None else True
+        )
 
         self.delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
         self.partition_tags: Dict[str, DelegationSpec] = {}
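The replaced line was effectively dead logic: `_lower_recomposed_sdpa or True` only keeps the left operand when it is truthy, so both the default `None` and an explicit `False` collapsed to `True`. A minimal sketch of the difference, using hypothetical helper names rather than the real partitioner API:

def resolve_flag_buggy(flag=None):
    # `flag or True` keeps `flag` only when it is truthy, so an explicit
    # False from the caller is silently replaced by True.
    return flag or True


def resolve_flag_fixed(flag=None):
    # Substitute the default only when the caller passed nothing at all.
    return flag if flag is not None else True


assert resolve_flag_buggy(False) is True   # caller's False is ignored
assert resolve_flag_fixed(False) is False  # caller's False is respected
assert resolve_flag_fixed(None) is True    # default still applies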

examples/models/llama2/tokenizer/llama_tiktoken.cpp

Lines changed: 8 additions & 0 deletions
@@ -97,5 +97,13 @@ const Encoder LlamaTiktoken::get_special_tokens(ssize_t num_base_tokens) const {
       return _get_default_special_tokens(num_base_tokens);
   }
 }
+
+const std::string LlamaTiktoken::get_bos_token() const {
+  return "<|begin_of_text|>";
+}
+
+const std::string LlamaTiktoken::get_eos_token() const {
+  return "<|end_of_text|>";
+}
 } // namespace executor
 } // namespace torch

examples/models/llama2/tokenizer/llama_tiktoken.h

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ class LlamaTiktoken : public Tiktoken {
 
  protected:
   const Encoder get_special_tokens(ssize_t num_base_tokens) const override;
+  const std::string get_bos_token() const override;
+  const std::string get_eos_token() const override;
 
  private:
   const Version _version;

examples/models/llama2/tokenizer/tiktoken.cpp

Lines changed: 2 additions & 2 deletions
@@ -346,8 +346,8 @@ Error Tiktoken::load(const std::string& path) {
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = _encoder.size() + _special_token_encoder.size();
-  bos_tok_ = _special_token_encoder.at("<|begin_of_text|>");
-  eos_tok_ = _special_token_encoder.at("<|end_of_text|>");
+  bos_tok_ = _special_token_encoder.at(get_bos_token());
+  eos_tok_ = _special_token_encoder.at(get_eos_token());
 
   initialized_ = true;
   return Error::Ok;

examples/models/llama2/tokenizer/tiktoken.h

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ class Tiktoken : public Tokenizer {
  protected:
   // Provide model specific special tokens.
   virtual const Encoder get_special_tokens(ssize_t num_base_tokens) const = 0;
+  // Provide beginning of sentence token.
+  virtual const std::string get_bos_token() const = 0;
+  // Provide end of sentence token.
+  virtual const std::string get_eos_token() const = 0;
 
  private:
   template <typename T>
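Taken together, the four tokenizer files move the hard-coded llama token strings out of the base `Tiktoken::load()` and behind virtual hooks that `LlamaTiktoken` overrides. A rough Python sketch of the same template-method pattern (hypothetical class and method names, not the actual C++ API):

from abc import ABC, abstractmethod


class Tokenizer(ABC):
    """Base loader: knows nothing about any particular model's tokens."""

    def load(self, special_token_encoder: dict) -> None:
        # Resolve model-specific tokens through the subclass hooks instead
        # of hard-coding llama strings in the base class.
        self.bos_tok = special_token_encoder[self.get_bos_token()]
        self.eos_tok = special_token_encoder[self.get_eos_token()]

    @abstractmethod
    def get_bos_token(self) -> str: ...

    @abstractmethod
    def get_eos_token(self) -> str: ...


class LlamaTokenizer(Tokenizer):
    def get_bos_token(self) -> str:
        return "<|begin_of_text|>"

    def get_eos_token(self) -> str:
        return "<|end_of_text|>"


# Placeholder token ids, only to show the lookup going through the hooks.
tok = LlamaTokenizer()
tok.load({"<|begin_of_text|>": 1, "<|end_of_text|>": 2})
assert tok.bos_tok == 1 and tok.eos_tok == 2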

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions
@@ -219,6 +219,9 @@ def verify_graph_input_output(self) -> None:
             if _is_mutable_buffer(nd, self.graph_signature):
                 continue
             assert len(specs) > 0, "Expect tensor specs"
+            specs = list(filter(lambda spec: not spec.const, specs))
+            if len(specs) == 0:
+                continue
             allocated = any(
                 spec is None or spec.mem_offset is not None for spec in specs
             )
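The new lines drop constant tensor specs before the allocation check, presumably because constants are not given memory offsets by the planner; an output made up entirely of constants then has nothing left to verify. A toy illustration of the filter with a stand-in spec type (not the real exir `TensorSpec`):

from dataclasses import dataclass
from typing import Optional


@dataclass
class Spec:                      # stand-in for exir's TensorSpec
    const: bool
    mem_offset: Optional[int] = None


specs = [Spec(const=True), Spec(const=False, mem_offset=64)]

# Constants are excluded first; only the remaining specs must carry offsets.
specs = list(filter(lambda spec: not spec.const, specs))
if specs:
    assert all(spec.mem_offset is not None for spec in specs)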

exir/serde/export_serialize.py

Lines changed: 6 additions & 0 deletions
@@ -1058,12 +1058,18 @@ def serialize(self, exported_program: ep.ExportedProgram) -> SerializedArtifact:
             assert n not in constants
             constants[n] = t
 
+        additional_kwargs = {}
+        if hasattr(exported_program, "verifiers"):
+            additional_kwargs["verifiers"] = [
+                v.dialect for v in exported_program.verifiers
+            ]
         serialized_ep = ExportedProgram(
             graph_module=serialized_graph_module,
             opset_version=self.opset_version,
             range_constraints=serialized_range_constraints,
             schema_version=SchemaVersion(-1, -1),
             dialect=exported_program.dialect,
+            **additional_kwargs,
         )
 
         return SerializedArtifact(

exir/serde/serialize.py

Lines changed: 6 additions & 0 deletions
@@ -344,13 +344,19 @@ def serialize(
             assert n not in constants
             constants[n] = t
 
+        additional_kwargs = {}
+        if hasattr(exported_program, "verifiers"):
+            additional_kwargs["verifiers"] = [
+                v.dialect for v in exported_program.verifiers
+            ]
         return export_serialize.SerializedArtifact(
             schema.ExportedProgram(
                 graph_module=serialized_graph_module,
                 opset_version=self.opset_version,
                 range_constraints=serialized_range_constraints,
                 schema_version=SchemaVersion(-1, -1),
                 dialect=exported_program.dialect,
+                **additional_kwargs,
             ),
             export_serialize.serialize_torch_artifact(exported_program.state_dict),
             export_serialize.serialize_torch_artifact(constants),
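Both serializers guard the new field the same way, so the code still runs against torch releases whose `ExportedProgram` has no `verifiers` attribute: the optional value is gathered into a dict and splatted into the constructor only when present. A self-contained sketch of the pattern with made-up record and program classes:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Record:
    dialect: str
    verifiers: Optional[list] = None  # newer, optional field


class OldProgram:                     # mimics an older exported program
    dialect = "EDGE"


class NewProgram(OldProgram):         # mimics a newer one carrying verifiers
    class _Verifier:
        dialect = "EDGE"

    verifiers = [_Verifier()]


def serialize(program) -> Record:
    additional_kwargs = {}
    if hasattr(program, "verifiers"):
        additional_kwargs["verifiers"] = [v.dialect for v in program.verifiers]
    # The keyword is only passed when the attribute exists, so the same call
    # site works for both old and new program objects.
    return Record(dialect=program.dialect, **additional_kwargs)


assert serialize(OldProgram()).verifiers is None
assert serialize(NewProgram()).verifiers == ["EDGE"]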

kernels/portable/cpu/op_convolution.cpp

Lines changed: 15 additions & 16 deletions
@@ -136,17 +136,6 @@ void conv2d_impl(
       }
     }
   } else { // transposed convolution
-    if (bias_ptr != nullptr) {
-      out_coord[2] = 0;
-      out_coord[3] = 0;
-      size_t out_c_start_idx =
-          calculate_linear_index(out_coord, out_strides.data(), 4);
-      size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
-      for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
-        out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
-      }
-    }
-
     w_coord[1] = out_c - out_c_start;
 
     for (size_t in_y = 0; in_y < in_H; ++in_y) {
@@ -295,12 +284,22 @@ void convolution_wrapper(
       bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
 
   size_t out_N = out.size(0);
-  size_t out_C_per_group = out.size(1) / groups;
+  size_t out_C = out.size(1);
+  size_t out_C_per_group = out_C / groups;
 
-  if (transposed && bias_ptr == nullptr) {
-    // If bias is not present, we need to initialize the output to 0
-    // before we can accumulate into it.
-    memset(out_ptr, 0, out.nbytes());
+  if (transposed) {
+    // For transposed convolution, we need to initialize the output before we
+    // can accumulate into it.
+    if (bias_ptr == nullptr) {
+      // If bias is not present, we need to initialize the output to 0
+      memset(out_ptr, 0, out.nbytes());
+    } else {
+      // If bias is present, we initialize the output to the bias value
+      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+        out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
+            bias_ptr[(out_ix / out_strides[1]) % out_C]);
+      }
+    }
   }
 
   for (size_t batch = 0; batch < out_N; ++batch) {
kernels/test/op_convolution_test.cpp

Lines changed: 113 additions & 1 deletion
@@ -556,7 +556,7 @@ TEST_F(OpConvCorrectnessTest, TransposedNonDefaultParams) {
   Tensor input = tf.full({2, 6, 4, 5}, 2.0);
   Tensor weight = tf.full({6, 1, 2, 2}, 0.5);
   Tensor bias = tf.make({3}, {1, 2, 3});
-  Tensor out = tf.zeros({2, 3, 3, 6});
+  Tensor out = tf.full({2, 3, 3, 6}, 0.7);
   Tensor expected = tf.make(
       {2, 3, 3, 6},
       {1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 2, 2, 2, 2,
@@ -587,6 +587,118 @@ TEST_F(OpConvCorrectnessTest, TransposedNonDefaultParams) {
   EXPECT_TENSOR_CLOSE(out, expected);
 }
 
+template <typename T>
+std::vector<T> get_channels_last_data(const Tensor& t) {
+  const std::vector<int32_t> sizes(t.sizes().begin(), t.sizes().end());
+  std::vector<T> contiguous_data(
+      t.const_data_ptr<T>(), t.const_data_ptr<T>() + t.numel());
+  std::vector<T> channels_last_data(t.numel());
+  int32_t N = sizes[0];
+  int32_t C = sizes[1];
+  int32_t H = sizes[2];
+  int32_t W = sizes[3];
+  for (int32_t n = 0; n < N; ++n) {
+    for (int32_t c = 0; c < C; ++c) {
+      for (int32_t h = 0; h < H; ++h) {
+        for (int32_t w = 0; w < W; ++w) {
+          // Calculate the index in the original blob
+          int32_t old_index = ((n * C + c) * H + h) * W + w;
+          // Calculate the index in the new blob
+          int32_t new_index = ((n * H + h) * W + w) * C + c;
+          // Copy the data
+          channels_last_data[new_index] = contiguous_data[old_index];
+        }
+      }
+    }
+  }
+  return channels_last_data;
+}
+
+TEST_F(OpConvCorrectnessTest, TransposedDefaultParamsChannelsLast) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor input = tf.full_channels_last({2, 4, 3, 2}, 2.0);
+  Tensor weight = tf.full_channels_last({4, 1, 2, 2}, 0.5);
+  optional<Tensor> bias;
+  Tensor out = tf.full_channels_last({2, 2, 4, 3}, 0.7);
+  Tensor expected =
+      tf.make({2, 2, 4, 3}, {2, 4, 2, 4, 8, 4, 4, 8, 4, 2, 4, 2, 2, 4, 2, 4,
+                             8, 4, 4, 8, 4, 2, 4, 2, 2, 4, 2, 4, 8, 4, 4, 8,
+                             4, 2, 4, 2, 2, 4, 2, 4, 8, 4, 4, 8, 4, 2, 4, 2});
+
+  const std::vector<int32_t> sizes(
+      expected.sizes().begin(), expected.sizes().end());
+  std::vector<float> channels_last_data =
+      get_channels_last_data<float>(expected);
+  Tensor expected_channels_last =
+      tf.make_channels_last(sizes, channels_last_data);
+
+  int64_t stride[1] = {1};
+  int64_t padding[1] = {0};
+  int64_t dilation[1] = {1};
+  bool transposed = true;
+  int64_t output_padding[1] = {0};
+  int64_t groups = 2;
+
+  op_convolution_out(
+      input,
+      weight,
+      exec_aten::optional<Tensor>(bias),
+      exec_aten::ArrayRef<int64_t>{stride, 1},
+      exec_aten::ArrayRef<int64_t>{padding, 1},
+      exec_aten::ArrayRef<int64_t>{dilation, 1},
+      transposed,
+      exec_aten::ArrayRef<int64_t>{output_padding, 1},
+      groups,
+      out);
+
+  EXPECT_TENSOR_CLOSE(out, expected_channels_last);
+}
+
+TEST_F(OpConvCorrectnessTest, TransposedNonDefaultParamsChannelsLast) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor input = tf.full_channels_last({2, 6, 4, 5}, 2.0);
+  Tensor weight = tf.full_channels_last({6, 1, 2, 2}, 0.5);
+  Tensor bias = tf.make({3}, {1, 2, 3});
+  Tensor out = tf.full_channels_last({2, 3, 3, 6}, 0.7);
+  Tensor expected = tf.make(
+      {2, 3, 3, 6},
+      {1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 2, 2, 2, 2,
+       2, 2, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 4, 3, 3, 3, 3, 3, 3, 3, 5,
+       5, 3, 5, 5, 3, 5, 5, 3, 5, 5, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 3, 3,
+       1, 3, 3, 1, 3, 3, 2, 2, 2, 2, 2, 2, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2,
+       4, 4, 3, 3, 3, 3, 3, 3, 3, 5, 5, 3, 5, 5, 3, 5, 5, 3, 5, 5});
+
+  const std::vector<int32_t> sizes(
+      expected.sizes().begin(), expected.sizes().end());
+  std::vector<float> channels_last_data =
+      get_channels_last_data<float>(expected);
+  Tensor expected_channels_last =
+      tf.make_channels_last(sizes, channels_last_data);
+
+  int64_t stride[1] = {3};
+  int64_t padding[1] = {7};
+  int64_t dilation[1] = {5};
+  bool transposed = true;
+  int64_t output_padding[1] = {2};
+  int64_t groups = 3;
+
+  op_convolution_out(
+      input,
+      weight,
+      exec_aten::optional<Tensor>(bias),
+      exec_aten::ArrayRef<int64_t>{stride, 1},
+      exec_aten::ArrayRef<int64_t>{padding, 1},
+      exec_aten::ArrayRef<int64_t>{dilation, 1},
+      transposed,
+      exec_aten::ArrayRef<int64_t>{output_padding, 1},
+      groups,
+      out);
+
+  EXPECT_TENSOR_CLOSE(out, expected_channels_last);
+}
+
 TEST_F(OpConvCorrectnessTest, InvalidOutputPadding) {
   TensorFactory<ScalarType::Float> tf;
 
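The `get_channels_last_data` helper added in this test file is just an explicit NCHW-to-NHWC reordering. A short sketch (assuming NumPy is available) confirming that the two index formulas amount to a (0, 2, 3, 1) transpose of the contiguous blob:

import numpy as np

N, C, H, W = 2, 3, 4, 5
contiguous = np.arange(N * C * H * W, dtype=np.float32)  # flat NCHW data

channels_last = np.empty_like(contiguous)
for n in range(N):
    for c in range(C):
        for h in range(H):
            for w in range(W):
                old_index = ((n * C + c) * H + h) * W + w  # NCHW offset
                new_index = ((n * H + h) * W + w) * C + c  # NHWC offset
                channels_last[new_index] = contiguous[old_index]

# Same result as reshaping to NCHW and transposing to NHWC.
reference = contiguous.reshape(N, C, H, W).transpose(0, 2, 3, 1).ravel()
assert np.array_equal(channels_last, reference)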