CUDA: MMQ support for iq4_nl, iq4_xs #8278


Merged: 1 commit into ggml-org:master on Jul 5, 2024

Conversation

JohannesGaessler
Collaborator

This PR adds MMQ support for iq4_nl and iq4_xs. The data is loaded, converted to 8 bit, and written to shared memory. Because this is the same strategy as for q5_0, the same code can be reused except for the part that loads the data.

The other iq data types run into shared memory limits with the current MMQ code; it will need a refactor that makes the tile size in the k direction a configurable parameter. Because Ampere/Ada Lovelace consumer cards have 50% more shared memory than Turing, optimal performance would then presumably require different template instances for Turing and Ampere+. For A100s/H100s, which have even more shared memory, you would in principle also need different configurations, but I am not interested in working on that hardware since I will not be able to afford it anyways.

Since I am already working on the MMQ code I replaced instances of get_int_from_int8 with the refactored and simplified variants that accept void pointers.

Performance
| Model | GPU | Microbatch size | Test | t/s master | t/s cuda-iq-mmq-2 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 16 | pp2048 | 67.10 | 236.65 | 3.53 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 32 | pp2048 | 120.16 | 328.50 | 2.73 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 63 | pp2048 | 183.65 | 396.08 | 2.16 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 128 | pp2048 | 337.26 | 509.96 | 1.51 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 256 | pp2048 | 519.73 | 606.06 | 1.17 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 512 | pp2048 | 577.88 | 611.98 | 1.06 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 1024 | pp2048 | 594.07 | 687.13 | 1.16 |
| llama 8B IQ4_NL - 4.5 bpw | RX 6800 | 2048 | pp2048 | 581.07 | 625.90 | 1.08 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 16 | pp2048 | 342.89 | 1046.94 | 3.05 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 32 | pp2048 | 646.45 | 1713.55 | 2.65 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 63 | pp2048 | 1211.66 | 2552.69 | 2.11 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 128 | pp2048 | 2138.26 | 3225.39 | 1.51 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 256 | pp2048 | 3149.48 | 3621.78 | 1.15 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 512 | pp2048 | 3718.51 | 3762.09 | 1.01 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 1024 | pp2048 | 4357.64 | 3736.55 | 0.86 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 3090 | 2048 | pp2048 | 4376.58 | 3621.46 | 0.83 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 16 | pp2048 | 491.99 | 1950.90 | 3.97 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 32 | pp2048 | 970.65 | 3398.93 | 3.50 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 63 | pp2048 | 1754.37 | 5423.63 | 3.09 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 128 | pp2048 | 3432.75 | 7492.11 | 2.18 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 256 | pp2048 | 5706.46 | 9627.65 | 1.69 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 512 | pp2048 | 7641.30 | 10338.84 | 1.35 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 1024 | pp2048 | 8857.21 | 10008.78 | 1.13 |
| llama 8B IQ4_NL - 4.5 bpw | RTX 4090 | 2048 | pp2048 | 8729.23 | 9000.75 | 1.03 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 16 | pp2048 | 48.57 | 249.63 | 5.14 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 32 | pp2048 | 95.67 | 418.39 | 4.37 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 63 | pp2048 | 167.29 | 557.72 | 3.33 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 128 | pp2048 | 212.61 | 681.91 | 3.21 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 256 | pp2048 | 324.19 | 770.01 | 2.38 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 512 | pp2048 | 446.36 | 807.85 | 1.81 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 1024 | pp2048 | 511.27 | 797.13 | 1.56 |
| llama 8B IQ4_NL - 4.5 bpw | P40 | 2048 | pp2048 | 523.17 | 765.75 | 1.46 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 16 | pp2048 | 67.23 | 236.07 | 3.51 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 32 | pp2048 | 120.34 | 331.87 | 2.76 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 63 | pp2048 | 183.93 | 397.62 | 2.16 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 128 | pp2048 | 338.18 | 512.38 | 1.52 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 256 | pp2048 | 520.16 | 610.81 | 1.17 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 512 | pp2048 | 578.30 | 615.09 | 1.06 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 1024 | pp2048 | 596.10 | 692.03 | 1.16 |
| llama 8B IQ4_XS - 4.25 bpw | RX 6800 | 2048 | pp2048 | 580.27 | 630.13 | 1.09 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 16 | pp2048 | 345.80 | 1051.77 | 3.04 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 32 | pp2048 | 654.85 | 1701.79 | 2.60 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 63 | pp2048 | 1227.19 | 2525.81 | 2.06 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 128 | pp2048 | 2153.51 | 3166.41 | 1.47 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 256 | pp2048 | 3150.56 | 3497.60 | 1.11 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 512 | pp2048 | 3714.67 | 3629.68 | 0.98 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 1024 | pp2048 | 4279.30 | 3654.98 | 0.85 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 3090 | 2048 | pp2048 | 4264.30 | 3567.27 | 0.84 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 16 | pp2048 | 493.56 | 1970.30 | 3.99 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 32 | pp2048 | 976.59 | 3417.54 | 3.50 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 63 | pp2048 | 1766.28 | 5407.20 | 3.06 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 128 | pp2048 | 3471.43 | 7390.92 | 2.13 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 256 | pp2048 | 5724.22 | 9443.06 | 1.65 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 512 | pp2048 | 7643.08 | 10119.12 | 1.32 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 1024 | pp2048 | 8883.77 | 9814.66 | 1.10 |
| llama 8B IQ4_XS - 4.25 bpw | RTX 4090 | 2048 | pp2048 | 8770.67 | 8813.15 | 1.00 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 16 | pp2048 | 48.33 | 258.02 | 5.34 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 32 | pp2048 | 95.25 | 418.44 | 4.39 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 63 | pp2048 | 166.63 | 567.39 | 3.41 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 128 | pp2048 | 212.48 | 687.74 | 3.24 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 256 | pp2048 | 324.41 | 772.64 | 2.38 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 512 | pp2048 | 451.42 | 812.53 | 1.80 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 1024 | pp2048 | 512.68 | 802.59 | 1.57 |
| llama 8B IQ4_XS - 4.25 bpw | P40 | 2048 | pp2048 | 524.43 | 771.65 | 1.47 |

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs python python script changes labels Jul 3, 2024
@mofosyne mofosyne added the Review Complexity : High Generally require indepth knowledge of LLMs or GPUs label Jul 3, 2024
@ExtReMLapin
Contributor

Funny to see that on an RTX 4090 a higher microbatch size doesn't mean higher speed

@JohannesGaessler
Collaborator Author

This has nothing to do with MMQ though; the MMQ runtime still goes down by ~3% if you increase the batch size from 512 to 2048. The problem is instead inefficient masking in the FlashAttention kernel where larger batch sizes lead to iteration over more values that are masked out anyways. cuBLAS has the same problem but in return suffers less from dequantization overhead at large batch sizes.

@JohannesGaessler JohannesGaessler merged commit 8e55830 into ggml-org:master Jul 5, 2024
53 checks passed
@nitinrathi

nitinrathi commented Jul 5, 2024

@JohannesGaessler llama-bench is broken now

./llama-bench -m /media/nitin/soliddisk/Downloads/DeepSeek-Coder-V2-Lite-Instruct-IQ4_XS.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | ---------------: |
GGML_ASSERT: ggml/src/ggml-cuda/mmq.cuh:2189: false
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
[1]    1035704 IOT instruction (core dumped)  ./llama-bench -m 

@nitinrathi

nitinrathi commented Jul 5, 2024

@JohannesGaessler this broke the server as well

./llama-server --port 4786 --host "0.0.0.0" -m /media/nitin/soliddisk/Downloads/DeepSeek-Coder-V2-Lite-Instruct-IQ4_XS.gguf -ngl 99 -c 6000
INFO [                    main] build info | tid="123489702731776" timestamp=1720165430 build=3319 commit="148ec970"
INFO [                    main] system info | tid="123489702731776" timestamp=1720165430 n_threads=6 n_threads_batch=-1 total_threads=12 system_info="AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | "
llama_model_loader: loaded meta data with 42 key-value pairs and 377 tensors from /media/nitin/soliddisk/Downloads/DeepSeek-Coder-V2-Lite-Instruct-IQ4_XS.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.name str              = DeepSeek-Coder-V2-Lite-Instruct
llama_model_loader: - kv   2:                      deepseek2.block_count u32              = 27
llama_model_loader: - kv   3:                   deepseek2.context_length u32              = 163840
llama_model_loader: - kv   4:                 deepseek2.embedding_length u32              = 2048
llama_model_loader: - kv   5:              deepseek2.feed_forward_length u32              = 10944
llama_model_loader: - kv   6:             deepseek2.attention.head_count u32              = 16
llama_model_loader: - kv   7:          deepseek2.attention.head_count_kv u32              = 16
llama_model_loader: - kv   8:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv   9: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  10:                deepseek2.expert_used_count u32              = 6
llama_model_loader: - kv  11:                          general.file_type u32              = 30
llama_model_loader: - kv  12:        deepseek2.leading_dense_block_count u32              = 1
llama_model_loader: - kv  13:                       deepseek2.vocab_size u32              = 102400
llama_model_loader: - kv  14:           deepseek2.attention.kv_lora_rank u32              = 512
llama_model_loader: - kv  15:             deepseek2.attention.key_length u32              = 192
llama_model_loader: - kv  16:           deepseek2.attention.value_length u32              = 128
llama_model_loader: - kv  17:       deepseek2.expert_feed_forward_length u32              = 1408
llama_model_loader: - kv  18:                     deepseek2.expert_count u32              = 64
llama_model_loader: - kv  19:              deepseek2.expert_shared_count u32              = 2
llama_model_loader: - kv  20:             deepseek2.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  21:             deepseek2.rope.dimension_count u32              = 64
llama_model_loader: - kv  22:                deepseek2.rope.scaling.type str              = yarn
llama_model_loader: - kv  23:              deepseek2.rope.scaling.factor f32              = 40.000000
llama_model_loader: - kv  24: deepseek2.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  25: deepseek2.rope.scaling.yarn_log_multiplier f32              = 0.070700
llama_model_loader: - kv  26:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  27:                         tokenizer.ggml.pre str              = deepseek-llm
llama_model_loader: - kv  28:                      tokenizer.ggml.tokens arr[str,102400]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  29:                  tokenizer.ggml.token_type arr[i32,102400]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  30:                      tokenizer.ggml.merges arr[str,99757]   = ["Ġ Ġ", "Ġ t", "Ġ a", "i n", "h e...
llama_model_loader: - kv  31:                tokenizer.ggml.bos_token_id u32              = 100000
llama_model_loader: - kv  32:                tokenizer.ggml.eos_token_id u32              = 100001
llama_model_loader: - kv  33:            tokenizer.ggml.padding_token_id u32              = 100001
llama_model_loader: - kv  34:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  35:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  36:                    tokenizer.chat_template str              = {% if not add_generation_prompt is de...
llama_model_loader: - kv  37:               general.quantization_version u32              = 2
llama_model_loader: - kv  38:                      quantize.imatrix.file str              = /models/DeepSeek-Coder-V2-Lite-Instru...
llama_model_loader: - kv  39:                   quantize.imatrix.dataset str              = /training_data/calibration_datav3.txt
llama_model_loader: - kv  40:             quantize.imatrix.entries_count i32              = 293
llama_model_loader: - kv  41:              quantize.imatrix.chunks_count i32              = 139
llama_model_loader: - type  f32:  108 tensors
llama_model_loader: - type q6_K:    1 tensors
llama_model_loader: - type iq4_nl:   27 tensors
llama_model_loader: - type iq4_xs:  241 tensors
llm_load_vocab: special tokens cache size = 2400
llm_load_vocab: token to piece cache size = 0.6661 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = deepseek2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 102400
llm_load_print_meta: n_merges         = 99757
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 163840
llm_load_print_meta: n_embd           = 2048
llm_load_print_meta: n_layer          = 27
llm_load_print_meta: n_head           = 16
llm_load_print_meta: n_head_kv        = 16
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 192
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 3072
llm_load_print_meta: n_embd_v_gqa     = 2048
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 10944
llm_load_print_meta: n_expert         = 64
llm_load_print_meta: n_expert_used    = 6
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = yarn
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 0.025
llm_load_print_meta: n_ctx_orig_yarn  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = 16B
llm_load_print_meta: model ftype      = IQ4_XS - 4.25 bpw
llm_load_print_meta: model params     = 15.71 B
llm_load_print_meta: model size       = 7.98 GiB (4.36 BPW) 
llm_load_print_meta: general.name     = DeepSeek-Coder-V2-Lite-Instruct
llm_load_print_meta: BOS token        = 100000 '<|begin▁of▁sentence|>'
llm_load_print_meta: EOS token        = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: PAD token        = 100001 '<|end▁of▁sentence|>'
llm_load_print_meta: LF token         = 126 'Ä'
llm_load_print_meta: max token length = 256
llm_load_print_meta: n_layer_dense_lead   = 1
llm_load_print_meta: n_lora_q             = 0
llm_load_print_meta: n_lora_kv            = 512
llm_load_print_meta: n_ff_exp             = 1408
llm_load_print_meta: n_expert_shared      = 2
llm_load_print_meta: expert_weights_scale = 1.0
llm_load_print_meta: rope_yarn_log_mul    = 0.0707
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
llm_load_tensors: ggml ctx size =    0.32 MiB
llm_load_tensors: offloading 27 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 28/28 layers to GPU
llm_load_tensors:        CPU buffer size =   106.25 MiB
llm_load_tensors:      CUDA0 buffer size =  8064.46 MiB
......................................................................................
llama_new_context_with_model: n_ctx      = 6016
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 0.025
llama_kv_cache_init:      CUDA0 KV buffer size =  1586.25 MiB
llama_new_context_with_model: KV self size  = 1586.25 MiB, K (f16):  951.75 MiB, V (f16):  634.50 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.78 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =   223.75 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    15.76 MiB
llama_new_context_with_model: graph nodes  = 1924
llama_new_context_with_model: graph splits = 2
INFO [                    init] initializing slots | tid="123489702731776" timestamp=1720165432 n_slots=1
INFO [                    init] new slot | tid="123489702731776" timestamp=1720165432 id_slot=0 n_ctx_slot=6016
INFO [                    main] model loaded | tid="123489702731776" timestamp=1720165432
INFO [                    main] chat template | tid="123489702731776" timestamp=1720165432 chat_example="You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: How are you?\n\nAssistant:" built_in=true
INFO [                    main] HTTP server listening | tid="123489702731776" timestamp=1720165432 n_threads_http="11" port="4786" hostname="0.0.0.0"
INFO [            update_slots] all slots are idle | tid="123489702731776" timestamp=1720165432
INFO [   launch_slot_with_task] slot is processing task | tid="123489702731776" timestamp=1720165493 id_slot=0 id_task=0
INFO [            update_slots] kv cache rm [p0, end) | tid="123489702731776" timestamp=1720165493 id_slot=0 id_task=0 p0=0
GGML_ASSERT: ggml/src/ggml-cuda/mmq.cuh:2189: false
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
[1]    1033287 IOT instruction (core dumped)  ./llama-server --port 4786 --host "0.0.0.0" -m  -ngl 99 -c 6000

@Green-Sky
Collaborator

@nitinrathi check #8311

@JohannesGaessler
Collaborator Author

What commit are you on? Line 2189 of mmq.cuh does not have an assert on master so you cannot be using that. Please report whether it works with the latest master commit.

@nitinrathi

nitinrathi commented Jul 5, 2024

INFO [                    main] build info | tid="123489702731776" timestamp=1720165430 build=3319 commit="148ec970"

master is right now on 148ec97
It doesn't work against the latest master commit.

@JohannesGaessler
Collaborator Author

Okay, but as I said: there is no GGML_ASSERT on line 2189. That is the line where there was an assert prior to this PR. So I think something is wrong with your local install in particular; maybe for some reason old files are incorrectly being used. What happens when you compile with LLAMA_NO_CCACHE?

@JohannesGaessler
Collaborator Author

Actually before you try LLAMA_NO_CCACHE, try make clean if you haven't already.

@nitinrathi

@JohannesGaessler My apologies, I am very sorry. Everything works fine after compiling with LLAMA_NO_CCACHE.

@nitinrathi

Thank you for the performance improvements.
BEFORE

./llama-bench -m /media/nitin/soliddisk/Downloads/DeepSeek-Coder-V2-Lite-Instruct-IQ4_XS.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | ---------------: |
| deepseek2 16B IQ4_XS - 4.25 bpw |   7.98 GiB |    15.71 B | CUDA       |  99 |         pp512 |   1023.28 ± 4.99 |
| deepseek2 16B IQ4_XS - 4.25 bpw |   7.98 GiB |    15.71 B | CUDA       |  99 |         tg128 |     68.83 ± 1.40 |

AFTER

./llama-bench -m /media/nitin/soliddisk/Downloads/DeepSeek-Coder-V2-Lite-Instruct-IQ4_XS.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | ---------------: |
| deepseek2 16B IQ4_XS - 4.25 bpw |   7.98 GiB |    15.71 B | CUDA       |  99 |         pp512 |   1261.26 ± 4.13 |
| deepseek2 16B IQ4_XS - 4.25 bpw |   7.98 GiB |    15.71 B | CUDA       |  99 |         tg128 |     68.53 ± 1.86 |

@nitinrathi

@Green-Sky #8311 also works great.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jul 13, 2024
6 participants