
Commit cfe866e

Merge branch 'master' into pr/8836
2 parents fddff02 + fc54ef0 commit cfe866e

26 files changed, +546 −300 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
 - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
+- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

common/common.cpp

Lines changed: 57 additions & 7 deletions
@@ -77,6 +77,41 @@
 
 using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
@@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
@@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL", params.model);
+    get_env("LLAMA_ARG_THREADS", params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
+    get_env("LLAMA_ARG_BATCH", params.n_batch);
+    get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params

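The four `get_env()` overloads added above use `std::enable_if` to dispatch on the type of the field being filled in, so a single call spelling works for string, integer, floating-point, and boolean parameters, and an unset variable leaves the existing default untouched. Below is a minimal, self-contained sketch of that dispatch pattern (not part of the commit; the `main()`, the default values, and the reduced set of overloads are illustrative only):

```cpp
// Sketch: type-dispatched environment lookups in the style of the get_env()
// helpers above. The floating-point overload is omitted for brevity.
#include <cstdlib>
#include <iostream>
#include <string>
#include <type_traits>

template <typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::string(value) : target; // keep the default when unset
}

template <typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    target = value ? std::stoi(value) : target;   // keep the default when unset
}

template <typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
    char * value = std::getenv(name.c_str());
    if (value) {
        std::string val(value);
        target = val == "1" || val == "true";     // any other value clears the flag
    }
}

int main() {
    std::string model      = "default.gguf";
    int         n_ctx      = 512;
    bool        flash_attn = false;

    get_env("LLAMA_ARG_MODEL",      model);       // string overload
    get_env("LLAMA_ARG_CTX_SIZE",   n_ctx);       // integral overload
    get_env("LLAMA_ARG_FLASH_ATTN", flash_attn);  // bool overload

    std::cout << model << " " << n_ctx << " " << flash_attn << "\n";
    return 0;
}
```

Running the sketch with `LLAMA_ARG_CTX_SIZE=4096` in the environment prints 4096 for the context size; with nothing set, all three defaults survive.
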
common/common.h

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);

convert_hf_to_gguf.py

Lines changed: 10 additions & 22 deletions
@@ -295,6 +295,7 @@ def prepare_tensors(self):
                         gguf.MODEL_TENSOR.FFN_GATE_INP,
                         gguf.MODEL_TENSOR.POS_EMBD,
                         gguf.MODEL_TENSOR.TOKEN_TYPES,
+                        gguf.MODEL_TENSOR.SSM_CONV1D,
                     )
                 )
                 or not name.endswith(".weight")
@@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
-@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
@@ -2742,20 +2743,24 @@ def set_gguf_parameters(self):
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
         dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-
+        use_dt_b_c_norm = False
+        # For falconmamba we do apply RMS norm on B / DT and C layers
+        if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
+            use_dt_b_c_norm = True
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)
 
     _tok_embd = None
@@ -2782,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(new_name, data_torch)]
 
-    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        if bid is not None and new_name in (
-            self.format_tensor_name(
-                n, bid, ".weight" if name.endswith(".weight") else ""
-            )
-            for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        ):
-            return gguf.GGMLQuantizationType.F32
-
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
@@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        assert(hparams["activation_function"] == "silu")
+        assert (hparams["activation_function"] == "silu")
 
         max_position_embeddings = hparams["max_position_embeddings"]
         embed_dim = hparams["hidden_size"]
@@ -3855,8 +3843,8 @@ def prepare_tensors(self):
 
         super().prepare_tensors()
 
-###### CONVERSION LOGIC ######
 
+###### CONVERSION LOGIC ######
 
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):

examples/llava/clip.cpp

Lines changed: 9 additions & 1 deletion
@@ -20,6 +20,10 @@
 #include "ggml-cann.h"
 #endif
 
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 
@@ -1108,7 +1112,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
     }
 
-    clip_ctx * new_clip = new clip_ctx;
+    clip_ctx * new_clip = new clip_ctx{};
 
     // update projector type
     {
@@ -1142,6 +1146,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     LOG_TEE("%s: CLIP using CANN backend\n", __func__);
 #endif
 
+#ifdef GGML_USE_VULKAN
+    new_clip->backend = ggml_backend_vk_init(0);
+    LOG_TEE("%s: CLIP using Vulkan backend\n", __func__);
+#endif
 
     if (!new_clip->backend) {
         new_clip->backend = ggml_backend_cpu_init();

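One easy-to-miss change above is `new clip_ctx` becoming `new clip_ctx{}`. For a type with no user-provided constructor, value-initialization (`{}`) zero-fills members that lack default member initializers, while plain `new` leaves such members indeterminate; that matters when later code, like the CPU-fallback check here, tests a pointer member for null. A minimal sketch of the difference, using a stand-in struct rather than the real `clip_ctx` definition:

```cpp
// Sketch: new T vs new T{} for an aggregate with a raw pointer member.
#include <iostream>

struct ctx_t {
    void * backend;            // stand-in member with no default initializer
};

int main() {
    ctx_t * a = new ctx_t;     // a->backend is indeterminate: do not read it
    ctx_t * b = new ctx_t{};   // b->backend is guaranteed to be nullptr

    if (!b->backend) {
        std::cout << "no backend selected yet, fall back to CPU\n";
    }

    delete a;
    delete b;
    return 0;
}
```
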
examples/server/README.md

Lines changed: 19 additions & 0 deletions
@@ -247,6 +247,25 @@ logging:
   --log-append             Don't truncate the old log file.
 ```
 
+Available environment variables (if specified, these variables will override parameters specified in arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
 
 ## Build

examples/server/server.cpp

Lines changed: 3 additions & 0 deletions
@@ -2507,6 +2507,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;

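Note the ordering: `gpt_params_parse_from_env()` runs after the command-line parser has already filled `params`, which is what makes the `LLAMA_ARG_*` variables override CLI flags, as stated in the server README section above. A reduced sketch of that precedence (the two-field struct and helper names are stand-ins, not the real `gpt_params` API):

```cpp
// Sketch: environment values win because the env pass runs second.
#include <cstdlib>
#include <iostream>
#include <string>

struct params_t {
    std::string model = "models/7B/ggml-model-f16.gguf";
    int         n_ctx = 512;
};

static void parse_cli(params_t & p, int argc, char ** argv) {
    // stand-in for gpt_params_parse(): command-line flags are applied first
    for (int i = 1; i + 1 < argc; i += 2) {
        std::string arg = argv[i];
        if (arg == "--model")    p.model = argv[i + 1];
        if (arg == "--ctx-size") p.n_ctx = std::atoi(argv[i + 1]);
    }
}

static void parse_env(params_t & p) {
    // stand-in for gpt_params_parse_from_env(): runs second, so a set variable
    // overwrites whatever the CLI put into the struct
    if (const char * v = std::getenv("LLAMA_ARG_MODEL"))    p.model = v;
    if (const char * v = std::getenv("LLAMA_ARG_CTX_SIZE")) p.n_ctx = std::atoi(v);
}

int main(int argc, char ** argv) {
    params_t params;
    parse_cli(params, argc, argv);  // 1) command line
    parse_env(params);              // 2) environment overrides
    std::cout << "model=" << params.model << " n_ctx=" << params.n_ctx << "\n";
    return 0;
}
```

So, for example, `LLAMA_ARG_CTX_SIZE=8192 ./llama-server -c 2048` should end up with a context size of 8192.
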
ggml/src/ggml-sycl.cpp

Lines changed: 2 additions & 105 deletions
@@ -893,43 +893,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-template <typename T>
-static void im2col_kernel(const float *x, T *dst, int offset_delta,
-                          int IW, int IH, int OW, int KW, int KH,
-                          int pelements, int CHW, int s0, int s1, int p0,
-                          int p1, int d0, int d1,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (i >= pelements) {
-        return;
-    }
-
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
-
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
-
-    const int64_t offset_dst =
-        (item_ct1.get_group(1) * OW + ix) * CHW +
-        (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] =
-            sycl::vec<float, 1>(0.0f)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    } else {
-        const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
-        dst[offset_dst] =
-            sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    }
-}
-
 template <typename Ti, typename To>
 static void pool2d_nchw_kernel(
     const int ih, const int iw, const int oh, const int ow,
@@ -1742,32 +1705,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
         });
 }
 
-template <typename T>
-static void im2col_sycl(const float *x, T *dst, int IW, int IH,
-                        int OW, int OH, int KW, int KH, int IC,
-                        int offset_delta, int s0, int s1, int p0,
-                        int p1, int d0, int d1,
-                        queue_ptr stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
-    sycl::range<3> block_nums(IC, OH, num_blocks);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums *
-                                  sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
-                              parallel_elements, (IC * KH * KW), s0, s1, p0,
-                              p1, d0, d1, item_ct1);
-            });
-    }
-}
-
-
 static bool g_sycl_loaded = false;
 
 bool ggml_sycl_loaded(void) {
@@ -2636,47 +2573,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd,
-                                const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
-
-    const int64_t IC = src1->ne[is_2D ? 2 : 1];
-    const int64_t IH = is_2D ? src1->ne[1] : 1;
-    const int64_t IW = src1->ne[0];
-
-    const int64_t KH = is_2D ? src0->ne[1] : 1;
-    const int64_t KW = src0->ne[0];
-
-    const int64_t OH = is_2D ? dst->ne[2] : 1;
-    const int64_t OW = dst->ne[1];
-
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-
-    if (dst->type == GGML_TYPE_F16) {
-        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    } else {
-        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    }
-
-    (void) src0;
-    (void) src0_dd;
-}
-
 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,
@@ -3581,7 +3477,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
+        && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
