@@ -1651,46 +1651,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
1651
1651
1652
1652
// Create up to 2 variants, {f16,f32} accumulator; each is created only if the matching device->coopmat_acc_f16_support / device->coopmat_acc_f32_support flag is set
1653
1653
#define CREATE_MM2 (PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
1654
- CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1655
- CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1654
+ if (device->coopmat_acc_f16_support ) { \
1655
+ CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1656
+ } \
1657
+ if (device->coopmat_acc_f32_support ) { \
1658
+ CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1659
+ } \
1656
1660
1657
1661
CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1658
1662
CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1659
1663
CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1660
1664
CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1661
1665
1662
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1663
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1664
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1665
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1666
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1666
+ if (device->coopmat_acc_f16_support ) {
1667
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1668
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1669
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1670
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1671
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1672
+
1673
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1674
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1675
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1676
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1677
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1678
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1679
+ } else {
1680
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1681
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1682
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1683
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1684
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1667
1685
1668
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1669
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1670
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1671
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1672
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1673
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1686
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1687
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1688
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1689
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1690
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1691
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1692
+ }
1674
1693
1675
1694
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1676
1695
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l ) {
1677
1696
CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1678
1697
CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1679
1698
CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1680
1699
1681
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1682
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1683
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1684
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1685
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1686
-
1687
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1688
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1689
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1690
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1691
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1692
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1700
+ if (device->coopmat_acc_f16_support ) {
1701
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1702
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1703
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1704
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1705
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1706
+
1707
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1708
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1709
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1710
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1711
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1712
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1713
+ } else {
1714
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1715
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1716
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1717
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1718
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1719
+
1720
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1721
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1722
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1723
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1724
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1725
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1726
+ }
1693
1727
}
1728
+ #undef CREATE_MM2
1694
1729
#undef CREATE_MM
1695
1730
} else if (device->fp16 ) {
1696
1731
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1708,6 +1743,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1708
1743
if (device->mul_mat ## ID ## _s) \
1709
1744
ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_s , #NAMELC #F16ACC " _aligned_s" , NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1710
1745
1746
+ // Create 2 variants, {f16,f32} accumulator: PIPELINE_NAME.f16acc is built with the _f16acc shader-name suffix, PIPELINE_NAME.f32acc with an empty suffix; both are always created here
1747
+ #define CREATE_MM2 (PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
1748
+ CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1749
+ CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1750
+
1711
1751
CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1712
1752
CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1713
1753
CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
@@ -1745,6 +1785,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1745
1785
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1746
1786
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1747
1787
}
1788
+ #undef CREATE_MM2
1748
1789
#undef CREATE_MM
1749
1790
} else {
1750
1791
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1799,7 +1840,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
1799
1840
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc , matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1800
1841
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc , matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1801
1842
}
1802
- #undef CREATE_MM2
1803
1843
#undef CREATE_MM
1804
1844
}
1805
1845
@@ -2096,11 +2136,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
2096
2136
2097
2137
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2098
2138
2099
- // if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2100
- // // Intel drivers don't support coopmat properly yet
2101
- // // Only RADV supports coopmat properly on AMD
2102
- // device->coopmat_support = false;
2103
- // }
2139
+ if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2140
+ // Intel drivers don't support coopmat properly yet
2141
+ // Only RADV supports coopmat properly on AMD
2142
+ device->coopmat_support = false ;
2143
+ }
2104
2144
2105
2145
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device .getQueueFamilyProperties ();
2106
2146
@@ -2191,8 +2231,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
2191
2231
device->pipeline_robustness = pl_robustness_features.pipelineRobustness ;
2192
2232
2193
2233
device->subgroup_size_control = device->subgroup_size_control &&
2194
- (!( subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) ||
2195
- ! subgroup_size_control_features.subgroupSizeControl ) ;
2234
+ (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
2235
+ subgroup_size_control_features.subgroupSizeControl ;
2196
2236
2197
2237
if (device->subgroup_size_control ) {
2198
2238
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize ;
@@ -2350,7 +2390,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2350
2390
}
2351
2391
}
2352
2392
2353
- if (device->coopmat_m == 0 ) {
2393
+ if (device->coopmat_m == 0 || !device-> coopmat_acc_f32_support ) {
2354
2394
// No suitable matmul mode found, or coopmat_acc_f32_support is not set
2355
2395
GGML_LOG_DEBUG (" ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n " );
2356
2396
device->coopmat_support = false ;
@@ -2483,11 +2523,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2483
2523
}
2484
2524
}
2485
2525
2486
- // if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2487
- // // Intel drivers don't support coopmat properly yet
2488
- // // Only RADV supports coopmat properly on AMD
2489
- // coopmat_support = false;
2490
- // }
2526
+ if (props2.properties .vendorID == VK_VENDOR_ID_INTEL || (props2.properties .vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2527
+ // Intel drivers don't support coopmat properly yet
2528
+ // Only RADV supports coopmat properly on AMD
2529
+ coopmat_support = false ;
2530
+ }
2491
2531
2492
2532
const char * GGML_VK_DISABLE_F16 = getenv (" GGML_VK_DISABLE_F16" );
2493
2533
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr ;
@@ -2770,7 +2810,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2770
2810
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2771
2811
return ctx->device ->pipeline_matmul_f32_f16 ;
2772
2812
}
2773
- if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 ) {
2813
+ if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 && !(ctx-> device -> coopmat_support && !ctx-> device -> coopmat_acc_f16_support ) ) {
2774
2814
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2775
2815
return ctx->device ->pipeline_matmul_f16_f32 .f16acc ;
2776
2816
}
@@ -2845,7 +2885,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2845
2885
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2846
2886
return ctx->device ->pipeline_matmul_id_f32 ;
2847
2887
}
2848
- if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 ) {
2888
+ if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 && !(ctx-> device -> coopmat_support && !ctx-> device -> coopmat_acc_f16_support ) ) {
2849
2889
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2850
2890
return ctx->device ->pipeline_matmul_id_f16_f32 .f16acc ;
2851
2891
}
0 commit comments