@@ -32,35 +32,37 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
32
32
33
33
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
34
34
35
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36
+
35
37
/*
36
38
* Computes a depthwise convolution. Each shader invocation calculates the
37
39
* output at a single output location.
38
40
*/
39
41
void main() {
40
- const ivec3 pos = ivec3 (gl_GlobalInvocationID);
42
+ const u16vec3 pos = u16vec3 (gl_GlobalInvocationID);
41
43
42
44
if (any (greaterThanEqual (pos, out_limits))) {
43
45
return ;
44
46
}
45
47
46
48
// Compute the index of the top-left element of the overlay region. Negative
47
49
// indices indicate that the top-left element is in a region added by padding.
48
- const ivec2 ipos = pos.xy * stride - padding;
50
+ const u16vec2 ipos = pos.xy * u16vec2( stride) - u16vec2( padding) ;
49
51
50
52
// Compute the start and end of the input indices to load. Padding is assumed
51
53
// to be constant 0 padding, so any reads from the padding region is skipped.
52
- const ivec2 start = ipos;
53
- const ivec2 end = ipos + overlay_region.xy;
54
+ const u16vec2 start = ipos;
55
+ const u16vec2 end = ipos + u16vec2( overlay_region.xy) ;
54
56
55
- VEC4_T sum = texelFetch(t_bias, ivec2 (pos.z, 0 ), 0 );
56
- int kx = 0 ;
57
- for (int y = start.y, i = 0 ; i < TILE_SIZE; y += dilation.y, i++ ) {
58
- for (int x = start.x, j = 0 ; j < TILE_SIZE; x += dilation.x, j++ ) {
57
+ VEC4_T sum = texelFetch(t_bias, u16vec2 (pos.z, 0 ), 0 );
58
+ uint16_t kx = uint16_t( 0 ) ;
59
+ for (uint16_t y = start.y, i = uint16_t( 0 ) ; i < uint16_t( TILE_SIZE) ; y += uint16_t( dilation.y) , i++ ) {
60
+ for (uint16_t x = start.x, j = uint16_t( 0 ) ; j < uint16_t( TILE_SIZE) ; x += uint16_t( dilation.x) , j++ ) {
59
61
// The weight kernel was rearranged such that every NxN filter is
60
62
// flattened to fit in one row. Each filter was then stacked on top of
61
63
// each other vertically.
62
- const vec4 in_texel = texelFetch(t_in, ivec3 (x, y, pos.z), 0 );
63
- sum = fma(in_texel, texelFetch(t_kernel, ivec2 (kx, pos.z), 0 ), sum);
64
+ const vec4 in_texel = texelFetch(t_in, u16vec3 (x, y, pos.z), 0 );
65
+ sum = fma(in_texel, texelFetch(t_kernel, u16vec2 (kx, pos.z), 0 ), sum);
64
66
kx++ ;
65
67
}
66
68
}
0 commit comments