Commit 5920fe5

add fp16 support for qwen2vl and m-rope
1 parent 9f28932 commit 5920fe5

File tree: 3 files changed, +103 -36 lines changed

examples/llava/qwen2_vl_surgery.py
Lines changed: 13 additions & 7 deletions

@@ -37,7 +37,7 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
     vision_model = qwen2vl.visual
     tensor_map = {}
     for name, ten in vision_model.state_dict().items():
-        ten = ten.numpy().astype(dtype)
+        ten = ten.numpy()
         if 'qkv' in name:
             if ten.ndim == 2: # weight
                 c3, _ = ten.shape
@@ -68,18 +68,23 @@ def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
             tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
         else:
             tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
-
-    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=dtype) # dummy tensor, just here as a placeholder
+
+    for new_name, ten in tensor_map.items():
+        if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
+            tensor_map[new_name] = ten.astype(np.float32)
+        else:
+            tensor_map[new_name] = ten.astype(dtype)
+    tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
     return tensor_map


-def main(args, data_type='fp32'):
-    if data_type == 'fp32':
+def main(args):
+    if args.data_type == 'fp32':
         dtype = torch.float32
         np_dtype = np.float32
         ftype = 0
-    elif data_type == 'fp16':
-        dtype = torch.float16
+    elif args.data_type == 'fp16':
+        dtype = torch.float32
         np_dtype = np.float16
         ftype = 1
     else:
@@ -144,5 +149,6 @@ def main(args, data_type='fp32'):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
+    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
     args = parser.parse_args()
     main(args)
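In effect, the conversion script now loads the vision tower in float32 and decides per tensor how to store it: 1-D tensors and anything ending in "_norm.weight" are always written as float32, while the remaining weight matrices follow the requested --data_type (fp16 sets the GGUF file type to 1). With the new flag the script would presumably be invoked as `python qwen2_vl_surgery.py --data_type fp16`. Below is a minimal numpy sketch of that casting policy; the tensor names are made up for illustration:

import numpy as np

def cast_tensors(tensor_map, dtype=np.float16):
    # Illustrative restatement of the rule above: biases/1-D tensors and
    # *_norm.weight stay in fp32, larger weight matrices take the target dtype.
    out = {}
    for name, ten in tensor_map.items():
        if ten.ndim <= 1 or name.endswith("_norm.weight"):
            out[name] = ten.astype(np.float32)
        else:
            out[name] = ten.astype(dtype)
    return out

# toy example (names are made up, not the real GGUF tensor names)
demo = {
    "v.blk.0.attn_q.weight": np.ones((4, 4), dtype=np.float32),
    "v.blk.0.attn_q.bias":   np.ones((4,),   dtype=np.float32),
}
casted = cast_tensors(demo)
assert casted["v.blk.0.attn_q.weight"].dtype == np.float16
assert casted["v.blk.0.attn_q.bias"].dtype   == np.float32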

ggml/src/ggml-cuda/rope.cu
Lines changed: 26 additions & 19 deletions

@@ -181,15 +181,6 @@ static __global__ void rope_vision(
 
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
-    // if (i0 >= n_dims) {
-    //     const int i = row*ne0 + i0;
-
-    //     dst[i + 0] = x[i + 0];
-    //     dst[i + 1] = x[i + 1];
-
-    //     return;
-    // }
-
     const int i = row*ne0 + i0/2;
     const int i2 = row/p_delta_rows; // i2-th tokens
 
@@ -348,6 +339,14 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
+static void rope_mrope_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_mrope_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 static void rope_mrope_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
@@ -356,6 +355,14 @@ static void rope_mrope_cuda_f32(
     rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }
 
+static void rope_vision_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
 static void rope_vision_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
@@ -448,11 +455,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, sections, stream
             );
-        } else if (src0->type == GGML_TYPE_F16 && false) {
-            // rope_mrope_cuda_f16(
-            //     (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-            //     attn_factor, corr_dims, freq_factors, stream
-            // );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_mrope_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
         } else {
             GGML_ABORT("fatal error");
         }
@@ -462,11 +469,11 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
                 (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
                 attn_factor, corr_dims, freq_factors, sections, stream
            );
-        } else if (src0->type == GGML_TYPE_F16 && false) {
-            // rope_vision_cuda_f16(
-            //     (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-            //     attn_factor, corr_dims, freq_factors, stream
-            // );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
         } else {
             GGML_ABORT("fatal error");
         }

ggml/src/ggml.c
Lines changed: 64 additions & 10 deletions

@@ -11343,7 +11343,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -11480,11 +11480,10 @@ static void ggml_compute_forward_rope_f32(
                         const float x0 = src[0];
                         const float x1 = src[n_dims];
 
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
                         dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
                     }
-                }
-                else {
+                } else {
                     // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -11510,6 +11509,7 @@ static void ggml_compute_forward_rope_f16(
     const struct ggml_tensor * src2 = dst->src[2];
 
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
 
     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -11522,6 +11522,8 @@ static void ggml_compute_forward_rope_f16(
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -11554,6 +11556,12 @@ static void ggml_compute_forward_rope_f16(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
+    const bool is_vision = is_mrope && sections[3] > 0;
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
 
     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -11574,7 +11582,19 @@ static void ggml_compute_forward_rope_f16(
             const int64_t p = pos[i2];
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, sections[3] != 0,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
 
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
@@ -11594,6 +11614,22 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
+                } else if (is_vision){
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const int64_t ic = i0/2;
@@ -11612,12 +11648,30 @@ static void ggml_compute_forward_rope_f16(
                     }
                 }
 
-                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                    ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
 
-                    dst_data[0] = src[0];
-                    dst_data[1] = src[1];
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                } else {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
                 }
             }
         }
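The added f16 branches mirror the existing f32 vision path: each rotary pair couples channel ic with channel ic + n_dims, and because the new assert guarantees n_dims == ne0/2, every channel of a row is rotated, which is why the old pass-through copy of the upper half disappears. The cos/sin values come from the rope cache, which for m-rope is filled by ggml_mrope_cache_init from four positions per token (temporal, height, width, extra) chosen section by section. The numpy sketch below only illustrates those two ideas under simplified assumptions (no YaRN correction, no freq_factors, no per-section theta reset); it is not the exact ggml implementation:

import numpy as np

def vision_rope_row(x, cos, sin):
    # Rotate one row of length ne0 the way the is_vision branches above do:
    # pair (ic, ic + n_dims) with n_dims == ne0 // 2, so no channel is copied through.
    ne0 = x.shape[0]
    n_dims = ne0 // 2                      # matches GGML_ASSERT(n_dims == ne0/2)
    x = x.astype(np.float32)               # compute in fp32, like GGML_FP16_TO_FP32
    y = np.empty_like(x)
    for ic in range(n_dims):
        x0, x1 = x[ic], x[ic + n_dims]
        y[ic]          = x0 * cos[ic] - x1 * sin[ic]
        y[ic + n_dims] = x0 * sin[ic] + x1 * cos[ic]
    return y.astype(np.float16)            # store back in fp16, like GGML_FP32_TO_FP16

def mrope_position_for_pair(pair_idx, sections):
    # Which of the four positions (0=temporal, 1=height, 2=width, 3=extra)
    # drives rotary pair `pair_idx`; a simplified view of the sectioning
    # applied by ggml_mrope_cache_init.
    sector = pair_idx % sum(sections)
    if sector < sections[0]:
        return 0
    if sector < sections[0] + sections[1]:
        return 1
    if sector < sections[0] + sections[1] + sections[2]:
        return 2
    return 3

# toy usage: 8 channels -> 4 rotary pairs, each driven by one of (p_t, p_h, p_w, p_e)
sections = [1, 1, 1, 1]
print([mrope_position_for_pair(i, sections) for i in range(4)])   # [0, 1, 2, 3]
row = np.arange(8, dtype=np.float16)
print(vision_rope_row(row, cos=np.ones(4), sin=np.zeros(4)))      # identity rotation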
