CUDA: revise q8_1 data layout for mul_mat_q #7824

Merged: 1 commit merged into ggml-org:master on Jun 9, 2024

Conversation

JohannesGaessler (Collaborator)

This PR changes the data layout for the FP32 values quantized to q8_1 in conjunction with MMQ. The blocks of 32 values are consolidated into larger blocks of 128 values in order to make them 16-byte aligned and give them the same data layout in global and shared memory. This is relevant for asynchronous data loading (Ampere or newer, not in this PR). The memory layout in shared memory now also allows for more efficient data loading for int8 tensor cores: for optimal performance you need an offset of 16 bytes between columns, and it just so happens that for 128 values this is exactly the size that the q8_1 scales take up.
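
As a concrete illustration of that layout, a consolidated block along these lines could be declared as follows (a minimal sketch based on the description above; the names `block_q8_1_mmq`, `ds`, and `qs` are chosen for this example and the actual definition in the PR may differ):

```cpp
#include <cuda_fp16.h>

// One consolidated q8_1 block: 4 sub-blocks of 32 values each, i.e. 128 quantized values total.
// The 4 half2 scale/sum pairs take up exactly 16 bytes, which matches the 16-byte column offset
// that the int8 tensor core loads want, and the whole block is a multiple of 16 bytes, so it can
// use the same layout in global and shared memory and be fetched with 16-byte asynchronous loads.
struct block_q8_1_mmq {
    half2  ds[4];   // d (scale) and s (sum of values) for each sub-block of 32 values: 16 bytes
    int8_t qs[128]; // 128 quantized values: 128 bytes
};
static_assert(sizeof(block_q8_1_mmq) == 144, "unexpected block size");
static_assert(sizeof(block_q8_1_mmq) % 16 == 0, "block size must be a multiple of 16 bytes");
```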

The q8_1 memory layout in global memory for MMQ is that values are first consolidated into blocks of size 128 along ne10. These blocks are then treated as individual elements and transposed.
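
In other words, if src1 has ne11 rows of length ne10 (padded to a multiple of 128), the block covering columns [ib*128, (ib+1)*128) of row i1 ends up at block index ib*ne11 + i1 after the transpose, assuming both layouts are row-major. A hypothetical helper to illustrate this (not code from the PR):

```cpp
#include <cstdint>

// Illustrative only: where a consolidated 128-value block lands in the transposed layout.
// i1:   row index of src1 (0 <= i1 < ne11)
// ib:   index of the 128-value block along ne10 within that row
// ne11: number of src1 rows
// Blocks are treated as single elements of an (ne11 x ne10_padded/128) matrix and transposed,
// so all blocks at the same position along ne10 become contiguous across rows.
static int64_t mmq_q8_1_block_index(const int64_t i1, const int64_t ib, const int64_t ne11) {
    return ib*ne11 + i1;
}
```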

In ggml-cuda.cu I changed the argument that indicates whether or not src1 should be quantized into a function pointer to the specific quantization function that should be used. The data transfer for --split-mode row needs to be done differently, so I added a utility function cudaMemcpy2DPeerAsync since, for whatever reason, it does not exist in the CUDA toolkit.
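
A 2D peer copy can be sketched on top of the 1D cudaMemcpyPeerAsync that the toolkit does provide; this is only meant to illustrate the semantics, and the helper in the PR may be implemented differently (e.g. on top of cudaMemcpy3DPeerAsync) and use a different parameter order:

```cpp
#include <cuda_runtime.h>

// Sketch of a 2D peer-to-peer copy: `height` rows of `width` bytes each are copied from
// srcDevice to dstDevice, with row strides spitch/dpitch (in bytes) on the source/destination.
static cudaError_t cudaMemcpy2DPeerAsync(
        void * dst, size_t dpitch, int dstDevice,
        const void * src, size_t spitch, int srcDevice,
        size_t width, size_t height, cudaStream_t stream) {
    for (size_t i = 0; i < height; ++i) {
        const cudaError_t err = cudaMemcpyPeerAsync(
            (char *) dst + i*dpitch, dstDevice, (const char *) src + i*spitch, srcDevice, width, stream);
        if (err != cudaSuccess) {
            return err;
        }
    }
    return cudaSuccess;
}
```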

Performance on the RTX 4090/RTX 3090 stays mostly the same; the changes for the P40/RX 6800 are larger. On average, performance increases.

Specific numbers
| GPU | Model | Microbatch size | FlashAttention | Test | t/s master | t/s cuda-mmq-q8_1-2 | Speedup |
|---|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q2_K_M | 16 | Yes | pp512 | 1102.82 | 1118.51 | 1.01 |
| RTX 4090 | llama 8B Q2_K_M | 32 | Yes | pp512 | 1562.51 | 1584.70 | 1.01 |
| RTX 4090 | llama 8B Q2_K_M | 64 | Yes | pp512 | 2022.17 | 2033.73 | 1.01 |
| RTX 4090 | llama 8B Q2_K_M | 128 | Yes | pp512 | 3726.70 | 3728.89 | 1.00 |
| RTX 4090 | llama 8B Q2_K_M | 256 | Yes | pp512 | 5989.97 | 5990.15 | 1.00 |
| RTX 4090 | llama 8B Q2_K_M | 512 | Yes | pp512 | 7825.63 | 7816.57 | 1.00 |
| RTX 4090 | llama 8B Q3_K_S | 16 | Yes | pp512 | 888.89 | 887.58 | 1.00 |
| RTX 4090 | llama 8B Q3_K_S | 32 | Yes | pp512 | 1342.94 | 1367.00 | 1.02 |
| RTX 4090 | llama 8B Q3_K_S | 64 | Yes | pp512 | 1878.46 | 1895.10 | 1.01 |
| RTX 4090 | llama 8B Q3_K_S | 128 | Yes | pp512 | 3627.84 | 3626.18 | 1.00 |
| RTX 4090 | llama 8B Q3_K_S | 256 | Yes | pp512 | 5881.27 | 5882.97 | 1.00 |
| RTX 4090 | llama 8B Q3_K_S | 512 | Yes | pp512 | 7728.16 | 7744.08 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 16 | Yes | pp512 | 1598.48 | 1642.08 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 32 | Yes | pp512 | 2616.94 | 2644.69 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 64 | Yes | pp512 | 3693.61 | 3764.31 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 128 | Yes | pp512 | 3580.46 | 3577.67 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 256 | Yes | pp512 | 5838.55 | 5839.45 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | Yes | pp512 | 7775.65 | 7807.98 | 1.00 |
| RTX 4090 | llama 8B Q4_1 | 16 | Yes | pp512 | 1560.23 | 1568.59 | 1.01 |
| RTX 4090 | llama 8B Q4_1 | 32 | Yes | pp512 | 2574.10 | 2563.09 | 1.00 |
| RTX 4090 | llama 8B Q4_1 | 64 | Yes | pp512 | 3670.49 | 3702.49 | 1.01 |
| RTX 4090 | llama 8B Q4_1 | 128 | Yes | pp512 | 3535.60 | 3538.59 | 1.00 |
| RTX 4090 | llama 8B Q4_1 | 256 | Yes | pp512 | 5785.55 | 5788.89 | 1.00 |
| RTX 4090 | llama 8B Q4_1 | 512 | Yes | pp512 | 7658.33 | 7668.97 | 1.00 |
| RTX 4090 | llama 8B Q4_K_S | 16 | Yes | pp512 | 1392.65 | 1409.09 | 1.01 |
| RTX 4090 | llama 8B Q4_K_S | 32 | Yes | pp512 | 2367.72 | 2405.85 | 1.02 |
| RTX 4090 | llama 8B Q4_K_S | 64 | Yes | pp512 | 3131.95 | 3163.14 | 1.01 |
| RTX 4090 | llama 8B Q4_K_S | 128 | Yes | pp512 | 3582.14 | 3583.80 | 1.00 |
| RTX 4090 | llama 8B Q4_K_S | 256 | Yes | pp512 | 5850.85 | 5865.58 | 1.00 |
| RTX 4090 | llama 8B Q4_K_S | 512 | Yes | pp512 | 7783.61 | 7789.94 | 1.00 |
| RTX 4090 | llama 8B Q5_0 | 16 | Yes | pp512 | 1257.68 | 1252.06 | 1.00 |
| RTX 4090 | llama 8B Q5_0 | 32 | Yes | pp512 | 1988.88 | 1984.19 | 1.00 |
| RTX 4090 | llama 8B Q5_0 | 64 | Yes | pp512 | 2919.13 | 2956.97 | 1.01 |
| RTX 4090 | llama 8B Q5_0 | 128 | Yes | pp512 | 3480.97 | 3479.59 | 1.00 |
| RTX 4090 | llama 8B Q5_0 | 256 | Yes | pp512 | 5729.44 | 5731.15 | 1.00 |
| RTX 4090 | llama 8B Q5_0 | 512 | Yes | pp512 | 7690.57 | 7696.39 | 1.00 |
| RTX 4090 | llama 8B Q5_1 | 16 | Yes | pp512 | 1303.03 | 1315.82 | 1.01 |
| RTX 4090 | llama 8B Q5_1 | 32 | Yes | pp512 | 2156.25 | 2044.06 | 0.95 |
| RTX 4090 | llama 8B Q5_1 | 64 | Yes | pp512 | 2991.30 | 2888.40 | 0.97 |
| RTX 4090 | llama 8B Q5_1 | 128 | Yes | pp512 | 3494.70 | 3490.87 | 1.00 |
| RTX 4090 | llama 8B Q5_1 | 256 | Yes | pp512 | 5709.80 | 5707.48 | 1.00 |
| RTX 4090 | llama 8B Q5_1 | 512 | Yes | pp512 | 7644.22 | 7655.42 | 1.00 |
| RTX 4090 | llama 8B Q5_K_S | 16 | Yes | pp512 | 1194.83 | 1271.99 | 1.06 |
| RTX 4090 | llama 8B Q5_K_S | 32 | Yes | pp512 | 1913.44 | 2047.03 | 1.07 |
| RTX 4090 | llama 8B Q5_K_S | 64 | Yes | pp512 | 2695.83 | 2800.62 | 1.04 |
| RTX 4090 | llama 8B Q5_K_S | 128 | Yes | pp512 | 3547.30 | 3544.97 | 1.00 |
| RTX 4090 | llama 8B Q5_K_S | 256 | Yes | pp512 | 5789.94 | 5790.79 | 1.00 |
| RTX 4090 | llama 8B Q5_K_S | 512 | Yes | pp512 | 7723.31 | 7725.42 | 1.00 |
| RTX 4090 | llama 8B Q6_K | 16 | Yes | pp512 | 1096.41 | 1125.48 | 1.03 |
| RTX 4090 | llama 8B Q6_K | 32 | Yes | pp512 | 1839.73 | 1783.85 | 0.97 |
| RTX 4090 | llama 8B Q6_K | 64 | Yes | pp512 | 2574.44 | 2584.98 | 1.00 |
| RTX 4090 | llama 8B Q6_K | 128 | Yes | pp512 | 3456.14 | 3457.61 | 1.00 |
| RTX 4090 | llama 8B Q6_K | 256 | Yes | pp512 | 5640.23 | 5644.63 | 1.00 |
| RTX 4090 | llama 8B Q6_K | 512 | Yes | pp512 | 7442.21 | 7445.44 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 16 | Yes | pp512 | 1036.27 | 1054.82 | 1.02 |
| RTX 4090 | llama 8B Q8_0 | 32 | Yes | pp512 | 1795.01 | 1973.04 | 1.10 |
| RTX 4090 | llama 8B Q8_0 | 64 | Yes | pp512 | 2988.63 | 3063.31 | 1.02 |
| RTX 4090 | llama 8B Q8_0 | 128 | Yes | pp512 | 3377.35 | 3373.79 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 256 | Yes | pp512 | 5574.59 | 5571.14 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 512 | Yes | pp512 | 7515.20 | 7524.32 | 1.00 |
| RTX 3090 | llama 8B Q2_K_M | 16 | Yes | pp512 | 530.92 | 528.76 | 1.00 |
| RTX 3090 | llama 8B Q2_K_M | 32 | Yes | pp512 | 630.30 | 621.68 | 0.99 |
| RTX 3090 | llama 8B Q2_K_M | 64 | Yes | pp512 | 918.81 | 895.11 | 0.97 |
| RTX 3090 | llama 8B Q2_K_M | 128 | Yes | pp512 | 1159.97 | 1138.32 | 0.98 |
| RTX 3090 | llama 8B Q2_K_M | 256 | Yes | pp512 | 1502.26 | 1464.59 | 0.97 |
| RTX 3090 | llama 8B Q2_K_M | 512 | Yes | pp512 | 1541.40 | 1562.12 | 1.01 |
| RTX 3090 | llama 8B Q3_K_S | 16 | Yes | pp512 | 419.23 | 415.71 | 0.99 |
| RTX 3090 | llama 8B Q3_K_S | 32 | Yes | pp512 | 539.39 | 527.62 | 0.98 |
| RTX 3090 | llama 8B Q3_K_S | 64 | Yes | pp512 | 858.80 | 841.15 | 0.98 |
| RTX 3090 | llama 8B Q3_K_S | 128 | Yes | pp512 | 1203.35 | 1170.51 | 0.97 |
| RTX 3090 | llama 8B Q3_K_S | 256 | Yes | pp512 | 1555.93 | 1560.10 | 1.00 |
| RTX 3090 | llama 8B Q3_K_S | 512 | Yes | pp512 | 1652.37 | 1663.21 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 16 | Yes | pp512 | 949.48 | 939.88 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 32 | Yes | pp512 | 1267.78 | 1264.83 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 64 | Yes | pp512 | 1737.34 | 1726.78 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 128 | Yes | pp512 | 2054.06 | 2052.54 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 256 | Yes | pp512 | 2491.75 | 2500.63 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 512 | Yes | pp512 | 2620.33 | 2604.94 | 0.99 |
| RTX 3090 | llama 8B Q4_1 | 16 | Yes | pp512 | 925.82 | 920.58 | 0.99 |
| RTX 3090 | llama 8B Q4_1 | 32 | Yes | pp512 | 1239.31 | 1217.35 | 0.98 |
| RTX 3090 | llama 8B Q4_1 | 64 | Yes | pp512 | 1623.27 | 1625.47 | 1.00 |
| RTX 3090 | llama 8B Q4_1 | 128 | Yes | pp512 | 1955.09 | 1963.57 | 1.00 |
| RTX 3090 | llama 8B Q4_1 | 256 | Yes | pp512 | 2352.57 | 2298.17 | 0.98 |
| RTX 3090 | llama 8B Q4_1 | 512 | Yes | pp512 | 2479.74 | 2461.23 | 0.99 |
| RTX 3090 | llama 8B Q4_K_S | 16 | Yes | pp512 | 738.96 | 741.87 | 1.00 |
| RTX 3090 | llama 8B Q4_K_S | 32 | Yes | pp512 | 980.60 | 984.92 | 1.00 |
| RTX 3090 | llama 8B Q4_K_S | 64 | Yes | pp512 | 1466.61 | 1436.16 | 0.98 |
| RTX 3090 | llama 8B Q4_K_S | 128 | Yes | pp512 | 1742.16 | 1737.24 | 1.00 |
| RTX 3090 | llama 8B Q4_K_S | 256 | Yes | pp512 | 2123.83 | 2140.76 | 1.01 |
| RTX 3090 | llama 8B Q4_K_S | 512 | Yes | pp512 | 2277.97 | 2307.83 | 1.01 |
| RTX 3090 | llama 8B Q5_0 | 16 | Yes | pp512 | 657.66 | 648.62 | 0.99 |
| RTX 3090 | llama 8B Q5_0 | 32 | Yes | pp512 | 935.54 | 916.87 | 0.98 |
| RTX 3090 | llama 8B Q5_0 | 64 | Yes | pp512 | 1359.57 | 1305.98 | 0.96 |
| RTX 3090 | llama 8B Q5_0 | 128 | Yes | pp512 | 1782.90 | 1758.54 | 0.99 |
| RTX 3090 | llama 8B Q5_0 | 256 | Yes | pp512 | 2245.18 | 2206.38 | 0.98 |
| RTX 3090 | llama 8B Q5_0 | 512 | Yes | pp512 | 2380.48 | 2394.16 | 1.01 |
| RTX 3090 | llama 8B Q5_1 | 16 | Yes | pp512 | 723.15 | 712.35 | 0.99 |
| RTX 3090 | llama 8B Q5_1 | 32 | Yes | pp512 | 938.94 | 949.47 | 1.01 |
| RTX 3090 | llama 8B Q5_1 | 64 | Yes | pp512 | 1368.83 | 1356.26 | 0.99 |
| RTX 3090 | llama 8B Q5_1 | 128 | Yes | pp512 | 1753.49 | 1657.81 | 0.95 |
| RTX 3090 | llama 8B Q5_1 | 256 | Yes | pp512 | 2186.09 | 2136.65 | 0.98 |
| RTX 3090 | llama 8B Q5_1 | 512 | Yes | pp512 | 2312.84 | 2291.97 | 0.99 |
| RTX 3090 | llama 8B Q5_K_S | 16 | Yes | pp512 | 598.35 | 643.18 | 1.07 |
| RTX 3090 | llama 8B Q5_K_S | 32 | Yes | pp512 | 856.19 | 851.94 | 1.00 |
| RTX 3090 | llama 8B Q5_K_S | 64 | Yes | pp512 | 1308.06 | 1234.13 | 0.94 |
| RTX 3090 | llama 8B Q5_K_S | 128 | Yes | pp512 | 1591.37 | 1593.33 | 1.00 |
| RTX 3090 | llama 8B Q5_K_S | 256 | Yes | pp512 | 1994.02 | 1976.63 | 0.99 |
| RTX 3090 | llama 8B Q5_K_S | 512 | Yes | pp512 | 2135.53 | 2143.79 | 1.00 |
| RTX 3090 | llama 8B Q6_K | 16 | Yes | pp512 | 564.58 | 565.88 | 1.00 |
| RTX 3090 | llama 8B Q6_K | 32 | Yes | pp512 | 825.75 | 822.66 | 1.00 |
| RTX 3090 | llama 8B Q6_K | 64 | Yes | pp512 | 1230.89 | 1239.15 | 1.01 |
| RTX 3090 | llama 8B Q6_K | 128 | Yes | pp512 | 1541.75 | 1536.46 | 1.00 |
| RTX 3090 | llama 8B Q6_K | 256 | Yes | pp512 | 1927.92 | 1970.42 | 1.02 |
| RTX 3090 | llama 8B Q6_K | 512 | Yes | pp512 | 2088.83 | 2116.00 | 1.01 |
| RTX 3090 | llama 8B Q8_0 | 16 | Yes | pp512 | 603.33 | 645.74 | 1.07 |
| RTX 3090 | llama 8B Q8_0 | 32 | Yes | pp512 | 982.85 | 984.64 | 1.00 |
| RTX 3090 | llama 8B Q8_0 | 64 | Yes | pp512 | 1430.97 | 1456.70 | 1.02 |
| RTX 3090 | llama 8B Q8_0 | 128 | Yes | pp512 | 1856.46 | 1876.76 | 1.01 |
| RTX 3090 | llama 8B Q8_0 | 256 | Yes | pp512 | 2325.70 | 2321.93 | 1.00 |
| RTX 3090 | llama 8B Q8_0 | 512 | Yes | pp512 | 2471.43 | 2497.23 | 1.01 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 157.08 | 163.49 | 1.04 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 193.74 | 194.98 | 1.01 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 249.34 | 243.46 | 0.98 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 288.20 | 277.75 | 0.96 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 332.58 | 324.61 | 0.98 |
| RX 6800 | llama 8B Q2_K_M | 512 | No | pp512 | 358.45 | 353.03 | 0.98 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 109.38 | 110.86 | 1.01 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 146.56 | 147.55 | 1.01 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 216.45 | 199.18 | 0.92 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 244.60 | 223.69 | 0.91 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 287.36 | 262.28 | 0.91 |
| RX 6800 | llama 8B Q3_K_S | 512 | No | pp512 | 308.01 | 284.77 | 0.92 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 329.44 | 350.80 | 1.06 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 424.81 | 447.06 | 1.05 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 488.25 | 564.61 | 1.16 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 559.76 | 649.50 | 1.16 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 616.12 | 698.43 | 1.13 |
| RX 6800 | llama 8B Q4_0 | 512 | No | pp512 | 660.46 | 744.04 | 1.13 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 313.23 | 332.41 | 1.06 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 395.47 | 410.29 | 1.04 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 447.77 | 521.64 | 1.16 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 516.29 | 604.12 | 1.17 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 575.27 | 653.40 | 1.14 |
| RX 6800 | llama 8B Q4_1 | 512 | No | pp512 | 615.71 | 696.65 | 1.13 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 279.38 | 288.81 | 1.03 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 346.94 | 343.54 | 0.99 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 391.23 | 417.93 | 1.07 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 432.29 | 473.87 | 1.10 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 495.26 | 531.17 | 1.07 |
| RX 6800 | llama 8B Q4_K_S | 512 | No | pp512 | 531.88 | 572.07 | 1.08 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 254.52 | 265.04 | 1.04 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 370.51 | 386.30 | 1.04 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 458.79 | 518.40 | 1.13 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 532.04 | 599.45 | 1.13 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 582.78 | 648.47 | 1.11 |
| RX 6800 | llama 8B Q5_0 | 512 | No | pp512 | 620.75 | 688.35 | 1.11 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 266.91 | 272.65 | 1.02 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 363.78 | 373.59 | 1.03 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 437.54 | 488.39 | 1.12 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 515.97 | 574.23 | 1.11 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 572.46 | 627.92 | 1.10 |
| RX 6800 | llama 8B Q5_1 | 512 | No | pp512 | 612.17 | 671.81 | 1.10 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 262.77 | 265.41 | 1.01 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 348.23 | 356.84 | 1.02 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 412.91 | 412.09 | 1.00 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 464.64 | 463.63 | 1.00 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 526.68 | 523.03 | 0.99 |
| RX 6800 | llama 8B Q5_K_S | 512 | No | pp512 | 562.90 | 562.30 | 1.00 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 235.89 | 240.19 | 1.02 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 303.47 | 309.41 | 1.02 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 368.56 | 359.06 | 0.97 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 412.96 | 398.44 | 0.96 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 471.93 | 457.72 | 0.97 |
| RX 6800 | llama 8B Q6_K | 512 | No | pp512 | 502.60 | 489.37 | 0.97 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 282.22 | 296.78 | 1.05 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 397.59 | 422.71 | 1.06 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 525.22 | 544.64 | 1.04 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 618.31 | 634.40 | 1.03 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 675.13 | 692.29 | 1.03 |
| RX 6800 | llama 8B Q8_0 | 512 | No | pp512 | 723.71 | 738.57 | 1.02 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 237.34 | 245.33 | 1.03 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 346.03 | 365.32 | 1.06 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 425.68 | 499.35 | 1.17 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 514.15 | 572.58 | 1.11 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 569.00 | 623.27 | 1.10 |
| P40 | llama 8B Q2_K_M | 512 | Yes | pp512 | 599.91 | 653.08 | 1.09 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 191.82 | 196.45 | 1.02 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 308.71 | 323.82 | 1.05 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 403.49 | 485.15 | 1.20 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 494.54 | 555.72 | 1.12 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 549.80 | 602.46 | 1.10 |
| P40 | llama 8B Q3_K_S | 512 | Yes | pp512 | 580.61 | 634.97 | 1.09 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 449.64 | 450.38 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 631.98 | 643.95 | 1.02 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 725.83 | 765.18 | 1.05 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 846.09 | 891.05 | 1.05 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 938.43 | 977.63 | 1.04 |
| P40 | llama 8B Q4_0 | 512 | Yes | pp512 | 977.94 | 1021.74 | 1.04 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 449.09 | 459.07 | 1.02 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 551.55 | 634.79 | 1.15 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 709.22 | 746.86 | 1.05 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 832.00 | 877.13 | 1.05 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 924.13 | 959.59 | 1.04 |
| P40 | llama 8B Q4_1 | 512 | Yes | pp512 | 959.92 | 994.99 | 1.04 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 401.29 | 412.25 | 1.03 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 508.90 | 530.00 | 1.04 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 656.57 | 684.11 | 1.04 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 770.55 | 803.24 | 1.04 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 844.89 | 887.66 | 1.05 |
| P40 | llama 8B Q4_K_S | 512 | Yes | pp512 | 884.86 | 931.31 | 1.05 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 340.35 | 343.69 | 1.01 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 481.09 | 544.34 | 1.13 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 683.82 | 690.55 | 1.01 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 796.89 | 800.14 | 1.00 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 870.45 | 886.13 | 1.02 |
| P40 | llama 8B Q5_0 | 512 | Yes | pp512 | 905.58 | 921.42 | 1.02 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 381.26 | 380.18 | 1.00 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 505.76 | 508.57 | 1.01 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 675.41 | 697.82 | 1.03 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 788.00 | 811.61 | 1.03 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 864.09 | 887.09 | 1.03 |
| P40 | llama 8B Q5_1 | 512 | Yes | pp512 | 891.85 | 923.47 | 1.04 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 307.56 | 361.05 | 1.17 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 455.28 | 537.45 | 1.18 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 630.29 | 658.55 | 1.04 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 725.76 | 763.59 | 1.05 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 800.00 | 840.82 | 1.05 |
| P40 | llama 8B Q5_K_S | 512 | Yes | pp512 | 840.87 | 887.07 | 1.05 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 290.65 | 337.85 | 1.16 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 454.24 | 467.11 | 1.03 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 608.92 | 636.37 | 1.05 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 699.30 | 744.69 | 1.06 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 777.11 | 816.01 | 1.05 |
| P40 | llama 8B Q6_K | 512 | Yes | pp512 | 809.58 | 841.91 | 1.04 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 327.28 | 328.31 | 1.00 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 515.54 | 532.33 | 1.03 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 681.38 | 671.99 | 0.99 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 813.87 | 808.15 | 0.99 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 902.11 | 894.44 | 0.99 |
| P40 | llama 8B Q8_0 | 512 | Yes | pp512 | 947.15 | 936.98 | 0.99 |

The github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Jun 7, 2024.
github-actions bot (Contributor) commented on Jun 7, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 544 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8600.12ms p(95)=20560.13ms fails=, finish reason: stop=488 truncated=56
  • Prompt processing (pp): avg=103.63tk/s p(95)=467.48tk/s
  • Token generation (tg): avg=33.02tk/s p(95)=48.83tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-mmq-q8_1-2 commit=05a5fa08ce76c5d9cb6f1a9a08a6ac3d88d86029

[Benchmark charts (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 544 iterations): llamacpp:prompt_tokens_seconds, llamacpp:predicted_tokens_seconds, llamacpp:kv_cache_usage_ratio, llamacpp:requests_processing]

ggml-cuda.cu (outdated diff):

```cpp
@@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
    GGML_UNUSED(main_device);
}

static cudaError_t cudaMemcpy2DPeerAsync(
```
Member:
It's not good to "impersonate" the CUDA API: it can be confusing and it may cause conflicts in future CUDA versions. This should follow the same naming convention as the rest of the CUDA backend functions, e.g. ggml_cuda_xxx.

```cpp
    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
    const ggml_type type_x, cudaStream_t stream);

void quantize_row_q8_1_cuda(
```
Member:
It's not a big problem at the moment, but it is not great to export symbols with names without any kind of prefix or namespace, since it can easily lead to conflicts with other code. It would probably be easier to just move all the code in the CUDA backend to a namespace, but it's not important right now.
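
For instance (illustrative only, not code from the PR or the current backend), wrapping the exported helpers in a namespace could look like:

```cpp
#include <cstdint>
#include <cuda_runtime.h>
#include "ggml.h" // for ggml_type

// Hypothetical: exported CUDA backend helpers placed in a namespace to avoid symbol clashes.
namespace ggml_cuda {

void quantize_row_q8_1_cuda(
    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream);

} // namespace ggml_cuda
```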

JohannesGaessler merged commit 42b53d1 into ggml-org:master on Jun 9, 2024.
56 of 69 checks passed