#define PRECISION ${PRECISION}

- ${define_required_extensions("uint8")}
- ${define_required_extensions("int8")}
+ $if not NO_INT8_BUFFERS:
+   ${define_required_extensions("uint8")}
+ $if STORAGE == "buffer":
+   ${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
- ${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
+ $if NO_INT8_BUFFERS:
+   ${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
+ $else:
+   ${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

layout(push_constant) uniform restrict Block {
  ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

- uint8_t get_first(const uint8_t packed) {
-   return uint8_t((packed & 0xF0) >> 4);
+ $if NO_INT8_BUFFERS:
+   #define BUF_T uint
+ $else:
+   #define BUF_T uint8_t
+
+ $if STORAGE == "buffer":
+   #define UVEC4_T u8vec4
+ $else:
+   #define UVEC4_T uvec4
+
+ uint get_first(const BUF_T packed) {
+   return (packed & 0xF0) >> 4;
}

- uint8_t get_second(const uint8_t packed) {
-   return uint8_t(packed & 0x0F);
+ uint get_second(const BUF_T packed) {
+   return packed & 0x0F;
}

- uint8_t combine(const uint8_t first, const uint8_t second) {
-   return uint8_t(first << 4 | second);
+ uint combine(const uint first, const uint second) {
+   return (first << 4 | second);
}
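
For reference, the nibble helpers translate directly to host code. Below is a minimal stand-alone C sketch (plain-C stand-ins for the GLSL helpers above, not part of this diff) showing that get_first/get_second/combine round-trip a packed byte:

    #include <assert.h>
    #include <stdint.h>

    /* High nibble: bits 7..4 of the packed byte. */
    static uint32_t get_first(uint32_t packed) { return (packed & 0xF0u) >> 4; }

    /* Low nibble: bits 3..0 of the packed byte. */
    static uint32_t get_second(uint32_t packed) { return packed & 0x0Fu; }

    /* Pack two 4-bit values back into one byte. */
    static uint32_t combine(uint32_t first, uint32_t second) {
      return (first << 4) | second;
    }

    int main(void) {
      uint32_t packed = 0x5A; /* 5 in the high nibble, 10 (0xA) in the low */
      assert(get_first(packed) == 5);
      assert(get_second(packed) == 10);
      assert(combine(5, 10) == packed); /* round-trips */
      return 0;
    }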

- /*
-  * This shader packs the weight tensor into a texture.
-  *
-  * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
-  * is a uint8_t, which contains 2 packed 4-bit uint values.
-  *
-  * The transform performed by this shader is to first transpose the tensor, so
-  * the shape of the packed tensor becomes (N / 2, K). Then, the 4-bit integers
-  * are re-packed in groups of 8. Within each group of 4 uint8_t values, the
-  * "left" 4 bits of each value contain 4-bit values 0, 1, 2, 3, and the
-  * "right" 4 bits of each value contain 4-bit values 4, 5, 6, 7.
-  *
-  * As a concrete example, consider the following weight tensor. The | marks
-  * the packing boundary, so 1|2 represents a single uint8_t value with 1 in the
-  * leftmost 4 bits and 2 in the rightmost 4 bits.
-  *
-  *  1| 2,  3| 4,  5| 6,  7| 8,
-  *  9|10, 11|12, 13|14, 15|16,
-  * 17|18, 19|20, 21|22, 23|24,
-  * 25|26, 27|28, 29|30, 31|32,
-  * 33|34, 35|36, 37|38, 39|40,
-  * 41|42, 43|44, 45|46, 47|48,
-  * 49|50, 51|52, 53|54, 55|56,
-  * 57|58, 59|60, 61|62, 63|64,
-  *
-  * After packing, the packed tensor would contain
-  *
-  * 1|33,  9|41, 17|49, 25|57,
-  * 2|34, 10|42, 18|50, 26|58,
-  * 3|35, 11|43, 19|51, 27|59,
-  * 4|36, 12|44, 20|52, 28|60,
-  * 5|37, 13|45, 21|53, 29|61,
-  * 6|38, 14|46, 22|54, 30|62,
-  * 7|39, 15|47, 23|55, 31|63,
-  * 8|40, 16|48, 24|56, 32|64,
-  *
-  * The purpose of interleaving is to make it easier to extract the unpacked
-  * values in order using the u8vec4 vectorized type. With the packing in place,
-  * the 4-bit values can be extracted via
-  *
-  * u8vec4 packed;
-  * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
-  * u8vec4 vals_4567 = (packed & 0x0F);
-  */
+ $if NO_INT8_BUFFERS:
+   uint extract_comp(const uint packed4, const uint idx) {
+     return (packed4 >> (idx * 8)) & 0xFF;
+   }
+
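
As a cross-check of the interleaving described in the removed comment, the following stand-alone C sketch (illustrative only, not part of the shader) prints which original 4-bit values land in each packed byte; its output reproduces the "After packing" table above:

    #include <stdio.h>

    enum { N = 8, K = 8 }; /* 8 rows of 8 4-bit values (4 packed bytes per row) */

    int main(void) {
      /* Label the 4-bit slots 1..64 in row-major order, as in the comment. */
      int vals[N][K];
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          vals[n][k] = n * K + k + 1;

      /* After the transpose + interleave, packed byte (k, j) holds
       * vals[j][k] in its high nibble and vals[j + 4][k] in its low
       * nibble, for j = 0..3. */
      for (int k = 0; k < K; ++k)
        for (int j = 0; j < N / 2; ++j)
          printf("%2d|%2d%s", vals[j][k], vals[j + 4][k],
                 j + 1 < N / 2 ? ", " : ",\n");
      return 0;
    }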
void main() {
  // Each thread writes 2 output texels along the height axis
  ivec2 packed_pos = ivec2(
@@ -102,25 +78,32 @@ void main() {
  int in_numcols = qmat2_sizes.y;
  int in_num_int8_cols = qmat2_sizes.y >> 1;

-   uint8_t in_vals[8][2];
+   uint in_vals[8][2];
  for (int r = 0; r < 8; ++r) {
    if (in_row + r < in_numrows) {
-       uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
+       uint scalar_idx = (in_row + r) * in_num_int8_cols + in_int8_col;
+       $if NO_INT8_BUFFERS:
+         BUF_T in_val_packed_texel = nchw_4x2[scalar_idx >> 2];
+         const uint packed_idx = scalar_idx % 4;
+         uint in_val_packed = extract_comp(in_val_packed_texel, packed_idx);
+       $else:
+         BUF_T in_val_packed = nchw_4x2[scalar_idx];
+
      in_vals[r][0] = get_first(in_val_packed);
      in_vals[r][1] = get_second(in_val_packed);
    } else {
-       in_vals[r][0] = uint8_t(0);
-       in_vals[r][1] = uint8_t(0);
+       in_vals[r][0] = uint(0);
+       in_vals[r][1] = uint(0);
    }
  }

-   u8vec4 out_tex_1 = u8vec4(
+   UVEC4_T out_tex_1 = UVEC4_T(
    combine(in_vals[0][0], in_vals[4][0]),
    combine(in_vals[1][0], in_vals[5][0]),
    combine(in_vals[2][0], in_vals[6][0]),
    combine(in_vals[3][0], in_vals[7][0]));

-   u8vec4 out_tex_2 = u8vec4(
+   UVEC4_T out_tex_2 = UVEC4_T(
    combine(in_vals[0][1], in_vals[4][1]),
    combine(in_vals[1][1], in_vals[5][1]),
    combine(in_vals[2][1], in_vals[6][1]),
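
A note on the NO_INT8_BUFFERS fallback introduced in this change: when 8-bit storage buffers are unavailable, the same byte stream is bound as an array of uint, and extract_comp recovers one byte per call by shifting and masking. A minimal host-side C sketch of the equivalent indexing (this assumes little-endian byte order within each 32-bit word, which matches how a uint view of buffer bytes behaves on typical Vulkan implementations):

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    /* Mirror of the shader's extract_comp: pull byte `idx` (0..3)
     * out of a 32-bit word. */
    static uint32_t extract_comp(uint32_t packed4, uint32_t idx) {
      return (packed4 >> (idx * 8)) & 0xFFu;
    }

    int main(void) {
      /* A byte buffer viewed as 32-bit words, as the NO_INT8_BUFFERS
       * path views nchw_4x2. */
      uint8_t bytes[8] = {0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0};
      uint32_t words[2];
      memcpy(words, bytes, sizeof(bytes));

      /* Byte i of the buffer lives in word i >> 2, lane i % 4. */
      for (uint32_t i = 0; i < 8; ++i)
        assert(extract_comp(words[i >> 2], i % 4) == bytes[i]);
      return 0;
    }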