
Commit baa882a

Merge branch 'master' into xsn/qwen25omni

2 parents: 94d893d + 4f81b33


49 files changed: +1329, -529 lines

common/arg.cpp

Lines changed: 127 additions & 98 deletions
Large diffs are not rendered by default.

common/chat-parser.cpp

Lines changed: 7 additions & 4 deletions
@@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
 }
 
 // Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
-std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
     auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
     if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
         return std::nullopt;
     }
+    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
+    pos_ = m.groups[0].end;
+
+    if (add_prelude_to_content) {
+        add_content(prelude);
+    }
     if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
         if (is_partial()) {
             throw common_chat_msg_partial_exception(regex.str());
         }
         return std::nullopt;
     }
-    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
-    pos_ = m.groups[0].end;
-
     return find_regex_result{prelude, m.groups};
 }
 

common/chat-parser.h

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ class common_chat_msg_parser {
     const std::string & healing_marker() const { return healing_marker_; }
     const bool & is_partial() const { return is_partial_; }
     const common_chat_msg & result() const { return result_; }
+    const common_chat_syntax & syntax() const { return syntax_; }
 
     void move_to(size_t pos) {
         if (pos > input_.size()) {
@@ -77,7 +78,7 @@ class common_chat_msg_parser {
         std::vector<common_string_range> groups;
     };
 
-    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
+    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
 
     bool try_consume_literal(const std::string & literal);
 
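The new `add_prelude_to_content` flag (default `true`, preserving the old behaviour) lets a format-specific parser take ownership of the text preceding a match instead of having it appended to the message content automatically. A minimal sketch of the call pattern, with a hypothetical tag regex and handler that are not part of this commit:

```cpp
#include "chat-parser.h"  // common_chat_msg_parser, common_regex

// Hypothetical handler: locate an opening tag but keep the prelude out of the
// content so it can be inspected first. The regex below is illustrative only.
static const common_regex open_tag_regex("<tool_call>");

static void parse_example(common_chat_msg_parser & builder) {
    // add_prelude_to_content = false: the prelude is still returned in the
    // result, but no longer added to the message content automatically.
    if (auto res = builder.try_find_regex(open_tag_regex, std::string::npos, /* add_prelude_to_content = */ false)) {
        builder.add_content(res->prelude);   // caller decides what to do with it
        // ... parse whatever follows the tag ...
    } else {
        builder.add_content(builder.consume_rest());
    }
}
```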

common/chat.cpp

Lines changed: 181 additions & 131 deletions
Large diffs are not rendered by default.

common/chat.h

Lines changed: 4 additions & 1 deletion
@@ -123,6 +123,7 @@ struct common_chat_templates_inputs {
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
+    bool enable_thinking = true;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
@@ -143,6 +144,7 @@ struct common_chat_syntax {
     // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
     bool reasoning_in_content = false;
     bool thinking_forced_open = false;
+    bool parse_tool_calls = true;
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
@@ -181,7 +183,8 @@ std::string common_chat_format_example(
     const struct common_chat_templates * tmpls,
     bool use_jinja);
 
-std::string common_chat_format_name(common_chat_format format);
+const char* common_chat_format_name(common_chat_format format);
+const char* common_reasoning_format_name(common_reasoning_format format);
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
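A minimal sketch of how the new `parse_tool_calls` field comes into play when parsing a model response; the input string and format value below are illustrative, not taken from this commit:

```cpp
#include "chat.h"  // common_chat_syntax, common_chat_parse

// Sketch only: parse a response with tool-call extraction disabled.
// Real callers set syntax.format from the detected chat template.
static common_chat_msg parse_plain(const std::string & raw) {
    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_CONTENT_ONLY;
    syntax.parse_tool_calls = false; // new field: skip tool-call extraction entirely

    // With parse_tool_calls == false the whole input ends up in the content.
    return common_chat_parse(raw, /* is_partial = */ false, syntax);
}
```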

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -291,6 +291,7 @@ struct common_params {
     int32_t verbosity = 0;
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1;   // layer range for control vector
+    bool offline = false;
 
     int32_t ppl_stride = 0;      // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
@@ -368,6 +369,7 @@ struct common_params {
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    int reasoning_budget = -1;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
 
     std::vector<std::string> api_keys;
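The two new fields sit next to the existing chat and reasoning settings. A minimal sketch of setting them programmatically; their wiring to command-line flags lives in the common/arg.cpp diff above (not rendered), so the comments below are assumptions rather than documented behaviour:

```cpp
#include "common.h"

// Sketch only: the comments are assumptions about intent, not documented behaviour.
static common_params make_offline_params() {
    common_params params;
    params.offline          = true; // assumption: avoid network access (e.g. model downloads)
    params.reasoning_budget = 0;    // assumption: 0 disables thinking; the default -1 means no limit
    return params;
}
```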

docs/backend/CANN.md

File mode changed: 100644 → 100755
Lines changed: 9 additions & 0 deletions
@@ -280,6 +280,15 @@ cmake --build build --config release
 ### **GitHub contribution**:
 Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
 
+## Updates
+### Basic Flash Attention Support
+The basic FA kernel with aclnnops has been added in aclnn_ops.cpp.
+Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap.
+Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future.
+
+Authors from Peking University: Bizhao Shi ([email protected]), Yuxin Yang ([email protected]), Ruiyang Ma ([email protected]), and Guojie Luo ([email protected]).
+
+We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request.
 
 ## TODO
 - Support more models and data types.
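For context, a minimal sketch (not part of this commit) of how a user requests flash attention when creating a llama.cpp context; whether the new CANN kernel is actually used then depends on the constraints noted above (FP16 KV cache, no logit softcap):

```cpp
#include "llama.h"

// Sketch only: `model` is assumed to have been loaded elsewhere
// (e.g. with llama_model_load_from_file()).
static llama_context * make_fa_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true;          // request flash attention
    cparams.type_k     = GGML_TYPE_F16; // FP16 KV cache, as the CANN FA kernel requires
    cparams.type_v     = GGML_TYPE_F16;
    return llama_init_from_model(model, cparams);
}
```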

examples/embedding/embedding.cpp

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
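The switch from llama_encode to llama_decode lets the same call path serve both encoder-only and decoder-only embedding models. A minimal sketch of the surrounding pattern, simplified from the example and assuming a single sequence and a context created with embeddings and pooling enabled:

```cpp
#include "llama.h"
#include <vector>

// Sketch only: run one batch through the model and copy back the pooled
// embedding for sequence 0.
static bool decode_and_get_embd(llama_context * ctx, llama_batch & batch, std::vector<float> & out) {
    if (llama_decode(ctx, batch) < 0) {
        return false; // failed to process
    }
    const int     n_embd = llama_model_n_embd(llama_get_model(ctx));
    const float * emb    = llama_get_embeddings_seq(ctx, 0);
    if (emb == nullptr) {
        return false;
    }
    out.assign(emb, emb + n_embd);
    return true;
}
```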

examples/retrieval/retrieval.cpp

Lines changed: 6 additions & 6 deletions
@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_encode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to encode\n", __func__);
+    if (llama_decode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to process\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_encode(ctx, batch, out, s, n_embd);
+            batch_process(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);
             p += s;
             s = 0;
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_encode(ctx, batch, out, s, n_embd);
+    batch_process(ctx, batch, out, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
-        batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
 
         common_batch_clear(query_batch);
 

examples/training/README.md

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ Proof of concept:
 
 ``` sh
 export model_name=llama_3.2-1b && export quantization=f32
-./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
-./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
+./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
+./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
 ```
 
 The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

ggml/src/ggml-cann/CMakeLists.txt

File mode changed: 100644 → 100755

ggml/src/ggml-cann/Doxyfile

File mode changed: 100644 → 100755

ggml/src/ggml-cann/acl_tensor.cpp

File mode changed: 100644 → 100755
Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
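With BF16 covered by the mapping, ggml BF16 tensors can be described to the ACL runtime. A minimal sketch of how the mapping is typically consumed; the ACL_DT_UNDEFINED fallback is an assumption about the function's existing default branch:

```cpp
#include "acl_tensor.h"  // ggml_cann_type_mapping

// Sketch only: resolve the ACL dtype for a ggml tensor before building an
// ACL tensor descriptor; the assert guards against unsupported ggml types.
static aclDataType cann_dtype_for(const ggml_tensor * t) {
    aclDataType dt = ggml_cann_type_mapping(t->type); // now covers GGML_TYPE_BF16 -> ACL_BF16
    GGML_ASSERT(dt != ACL_DT_UNDEFINED && "unsupported ggml type for CANN");
    return dt;
}
```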

ggml/src/ggml-cann/acl_tensor.h

File mode changed: 100644 → 100755
