@@ -108,27 +108,15 @@ void add_slice_tensor_copy_node(
       spec_vars));

   } else {
-    // GPU's coordinate is in x, y, z
-    int64_t gpu_dim = -1;
-    int64_t in_channel_stride = 1;
-    if (dim_index == kWidth4D) {
-      gpu_dim = 0; // width: x dimension in gpu
-      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (dim_index == kHeight4D) {
-      gpu_dim = 1; // height: y dimension
-      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (dim_index == kChannel4D) {
-      gpu_dim = 2; // channel: z dimension
-      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-      in_channel_stride = dim_at(in_sizes, kChannel4D);
-    } else {
-      gpu_dim = 3; // batch: w dimension
-
-      in_channel_stride = dim_at(in_sizes, kChannel4D);
-      if (packed_dim_idx == kChannel4D) {
-        // Due to channel packing, each batch value spans stride planes
-        in_channel_stride = utils::div_up_4(in_channel_stride);
-      }
+    // GPU's coordinate is in x = 0, y = 1, z = 2, w = 3
+    const int64_t gpu_dim = -(dim_index + 1);
+    // stride of input tensor's channel dimension
+    int64_t in_channel_stride = dim_at(in_sizes, kChannel4D);
+    VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+
+    // Due to channel packing, each batch value spans stride planes
+    if (dim_index == kBatch4D && packed_dim_idx == kChannel4D) {
+      in_channel_stride = utils::div_up_4(in_channel_stride);
     }

     std::string kernel_name = "slice_batch_height_width";
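The rewrite collapses the old per-dimension if/else ladder into the single expression `-(dim_index + 1)`. This works because the backend's negative WHCN dim indices line up with the GPU's x/y/z/w coordinates. Here is a minimal standalone sketch of that identity, assuming the `DimIndex` values conventionally used by ExecuTorch's Vulkan backend (`kWidth4D = -1` through `kBatch4D = -4`; the authoritative definitions live in the backend's DimUtils header):

```cpp
// Standalone sketch of the gpu_dim mapping in the diff above. The DimIndex
// values below are an assumption matching the Vulkan backend's WHCN
// convention; verify against the backend's DimUtils header before relying
// on them.
#include <cassert>
#include <cstdint>

enum DimIndex : int32_t {
  kWidth4D = -1,   // innermost dimension of a WHCN tensor
  kHeight4D = -2,
  kChannel4D = -3,
  kBatch4D = -4,   // outermost dimension
};

int main() {
  // -(dim_index + 1) folds the old if/else ladder into one expression:
  // width -> x (0), height -> y (1), channel -> z (2), batch -> w (3).
  assert(-(kWidth4D + 1) == 0);
  assert(-(kHeight4D + 1) == 1);
  assert(-(kChannel4D + 1) == 2);
  assert(-(kBatch4D + 1) == 3);
  return 0;
}
```

The remaining `kBatch4D && kChannel4D` guard preserves the old batch-branch behavior: with a channels-packed layout, four channels share one texel plane, so each batch occupies ceil(C / 4) planes rather than C, which is what `utils::div_up_4` computes.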