Skip to content

Commit 02630bc

Browse files
committed
[ET-VK][15/n] reconcile Dim4D and NchwDim
TSIA. Differential Revision: [D56731155](https://our.internmc.facebook.com/intern/diff/D56731155/) [ghstack-poisoned]
1 parent 8178226 commit 02630bc

File tree

7 files changed

+74
-93
lines changed

7 files changed

+74
-93
lines changed

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ void add_cat_default_node(
3131
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
3232
vTensorPtr t_out = graph.get_tensor(out);
3333

34-
NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);
34+
Dim4DType dim4d = normalize_to_dim4d(*t_out, dim);
3535

3636
// TODO: Find ways to factor out the similar code for width, height, and batch
37-
if (nchw_dim == DimWidth) {
37+
if (dim4d == DIM4D_WIDTH) {
3838
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
3939
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
4040

@@ -46,7 +46,7 @@ void add_cat_default_node(
4646
dst_offset.data[0] += range.data[0];
4747
}
4848

49-
} else if (nchw_dim == DimHeight) {
49+
} else if (dim4d == DIM4D_HEIGHT) {
5050
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
5151
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
5252

@@ -57,7 +57,7 @@ void add_cat_default_node(
5757
graph, input_ref, range, src_offset, dst_offset, out);
5858
dst_offset.data[1] += range.data[1];
5959
}
60-
} else if (nchw_dim == DimBatch) {
60+
} else if (dim4d == DIM4D_BATCH) {
6161
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
6262
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
6363

@@ -68,19 +68,19 @@ void add_cat_default_node(
6868
graph, input_ref, range, src_offset, dst_offset, out);
6969
dst_offset.data[2] += range.data[2];
7070
}
71-
} else if (nchw_dim == DimChannel) {
71+
} else if (dim4d == DIM4D_CHANNEL) {
7272
int32_t src_offset = 0;
7373
int32_t dst_offset = 0;
7474

7575
for (ValueRef input_ref : *input_list) {
7676
vTensorPtr t_in = graph.get_tensor(input_ref);
77-
int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
77+
int32_t range = dim_at(t_in->sizes(), DIM4D_CHANNEL);
7878
add_copy_channel_offset_node(
7979
graph, input_ref, range, src_offset, dst_offset, out);
8080
dst_offset += range;
8181
}
8282
} else {
83-
VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
83+
VK_THROW("Unexpected value of dim4d=", dim4d);
8484
}
8585
}
8686

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,23 +92,23 @@ void add_copy_channel_offset_node(
9292
VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3");
9393

9494
VK_CHECK_COND(
95-
dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
95+
dim_at<DIM4D_CHANNEL>(in_sizes) >= src_channel_offset + channel_range,
9696
"Src channel (",
9797
src_channel_offset,
9898
") and range (",
9999
channel_range,
100100
") should be less than or equal to input tensor's channel size (",
101-
dim_at<Dim4D::Channel>(in_sizes),
101+
dim_at<DIM4D_CHANNEL>(in_sizes),
102102
")");
103103

104104
VK_CHECK_COND(
105-
dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
105+
dim_at<DIM4D_CHANNEL>(out_sizes) >= dst_channel_offset + channel_range,
106106
"Dst channel (",
107107
dst_channel_offset,
108108
") and range (",
109109
channel_range,
110110
") should be less than or equal to output tensor's channel size (",
111-
dim_at<Dim4D::Channel>(out_sizes),
111+
dim_at<DIM4D_CHANNEL>(out_sizes),
112112
")");
113113

114114
VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
@@ -121,10 +121,10 @@ void add_copy_channel_offset_node(
121121
kernel_name.reserve(kShaderNameReserve);
122122
add_dtype_suffix(kernel_name, *t_out);
123123

124-
int32_t out_channels = dim_at<Dim4D::Channel>(out_sizes);
124+
int32_t out_channels = dim_at<DIM4D_CHANNEL>(out_sizes);
125125

126126
// Copy one batch at a time.
127-
for (int batch_idx = 0; batch_idx < dim_at<Dim4D::Batch>(in_sizes);
127+
for (int batch_idx = 0; batch_idx < dim_at<DIM4D_BATCH>(in_sizes);
128128
batch_idx++) {
129129
// Mapping the tensor NCHW coordinates into texture XYZ coordinates
130130
int32_t dst_first_z = dst_channel_offset / 4;
@@ -139,8 +139,8 @@ void add_copy_channel_offset_node(
139139
0, 0, dst_first_z + batch_idx * api::utils::div_up(out_channels, 4)};
140140

141141
uvec3 global_size{
142-
dim_at<Dim4D::Width>(in_sizes),
143-
dim_at<Dim4D::Height>(in_sizes),
142+
dim_at<DIM4D_WIDTH>(in_sizes),
143+
dim_at<DIM4D_HEIGHT>(in_sizes),
144144
api::utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
145145

146146
uvec3 local_size = adaptive_work_group_size(global_size);

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ void add_permute_node(
7070
kernel_name.reserve(kShaderNameReserve);
7171
add_dtype_suffix(kernel_name, *t_out);
7272

73-
uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
74-
uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
73+
uint32_t out_channels = dim_at<DIM4D_CHANNEL>(t_out->sizes());
74+
uint32_t in_channels = dim_at<DIM4D_CHANNEL>(t_in->sizes());
7575

7676
uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
7777
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,23 @@ void check_args(
3232
"Input tensor dim size must be not greater than the repeat argument's size");
3333

3434
VK_CHECK_COND(
35-
dim_at<Dim4D::Width>(in.sizes()) * dim_at<Dim4D::Width>(repeats) ==
36-
dim_at<Dim4D::Width>(out.sizes()),
35+
dim_at<DIM4D_WIDTH>(in.sizes()) * dim_at<DIM4D_WIDTH>(repeats) ==
36+
dim_at<DIM4D_WIDTH>(out.sizes()),
3737
"Output's width doesn't match input's width * repeat count");
3838

3939
VK_CHECK_COND(
40-
dim_at<Dim4D::Height>(in.sizes()) * dim_at<Dim4D::Height>(repeats) ==
41-
dim_at<Dim4D::Height>(out.sizes()),
40+
dim_at<DIM4D_HEIGHT>(in.sizes()) * dim_at<DIM4D_HEIGHT>(repeats) ==
41+
dim_at<DIM4D_HEIGHT>(out.sizes()),
4242
"Output's height doesn't match input's height * repeat count");
4343

4444
VK_CHECK_COND(
45-
dim_at<Dim4D::Channel>(in.sizes()) * dim_at<Dim4D::Channel>(repeats) ==
46-
dim_at<Dim4D::Channel>(out.sizes()),
45+
dim_at<DIM4D_CHANNEL>(in.sizes()) * dim_at<DIM4D_CHANNEL>(repeats) ==
46+
dim_at<DIM4D_CHANNEL>(out.sizes()),
4747
"Output's channel doesn't match input's channel * repeat count");
4848

4949
VK_CHECK_COND(
50-
dim_at<Dim4D::Batch>(in.sizes()) * dim_at<Dim4D::Batch>(repeats) ==
51-
dim_at<Dim4D::Batch>(out.sizes()),
50+
dim_at<DIM4D_BATCH>(in.sizes()) * dim_at<DIM4D_BATCH>(repeats) ==
51+
dim_at<DIM4D_BATCH>(out.sizes()),
5252
"Output's batch doesn't match input's batch * repeat count");
5353
}
5454

@@ -70,13 +70,13 @@ void add_repeat_channel_node(
7070
const std::vector<int64_t>& in_sizes = t_in->sizes();
7171

7272
int32_t in_width =
73-
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(in_sizes));
73+
api::utils::safe_downcast<int32_t>(dim_at<DIM4D_WIDTH>(in_sizes));
7474
int32_t in_height =
75-
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(in_sizes));
75+
api::utils::safe_downcast<int32_t>(dim_at<DIM4D_HEIGHT>(in_sizes));
7676
int32_t in_channel =
77-
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(in_sizes));
77+
api::utils::safe_downcast<int32_t>(dim_at<DIM4D_CHANNEL>(in_sizes));
7878
int32_t in_batch =
79-
api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Batch>(in_sizes));
79+
api::utils::safe_downcast<int32_t>(dim_at<DIM4D_BATCH>(in_sizes));
8080

8181
int32_t out_channel = repeat_channel * in_channel;
8282

@@ -142,7 +142,7 @@ void add_repeat_node(
142142
// dimension, we copy over the input texture to the output. In subsequent
143143
// dimensions, we read and write from the same tensor.
144144

145-
if (int64_t channel_repeat = dim_at<Dim4D::Channel>(repeats);
145+
if (int64_t channel_repeat = dim_at<DIM4D_CHANNEL>(repeats);
146146
channel_repeat == 1) {
147147
// If no repeat, short-cut to a direct copy
148148
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
@@ -156,12 +156,12 @@ void add_repeat_node(
156156

157157
// TODO: refactor width, height, and batch into a common helper function.
158158
// Width
159-
if (int64_t width_repeat = dim_at<Dim4D::Width>(repeats); width_repeat > 1) {
159+
if (int64_t width_repeat = dim_at<DIM4D_WIDTH>(repeats); width_repeat > 1) {
160160
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
161161

162162
for (int i = 1; i < width_repeat; ++i) {
163163
api::utils::ivec3 dst_offset = api::utils::make_ivec3(
164-
{i * dim_at<Dim4D::Width>(in_sizes), 0, 0}, false);
164+
{i * dim_at<DIM4D_WIDTH>(in_sizes), 0, 0}, false);
165165

166166
add_copy_offset_node(
167167
graph, out, running_range, src_offset, dst_offset, out);
@@ -171,13 +171,13 @@ void add_repeat_node(
171171
}
172172

173173
// Height
174-
if (int64_t height_repeat = dim_at<Dim4D::Height>(repeats);
174+
if (int64_t height_repeat = dim_at<DIM4D_HEIGHT>(repeats);
175175
height_repeat > 1) {
176176
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
177177

178178
for (int i = 1; i < height_repeat; ++i) {
179179
api::utils::ivec3 dst_offset = api::utils::make_ivec3(
180-
{0, i * dim_at<Dim4D::Height>(in_sizes), 0}, false);
180+
{0, i * dim_at<DIM4D_HEIGHT>(in_sizes), 0}, false);
181181

182182
add_copy_offset_node(
183183
graph, out, running_range, src_offset, dst_offset, out);
@@ -187,7 +187,7 @@ void add_repeat_node(
187187
}
188188

189189
// Batch
190-
if (int64_t batch_repeat = dim_at<Dim4D::Batch>(repeats); batch_repeat > 1) {
190+
if (int64_t batch_repeat = dim_at<DIM4D_BATCH>(repeats); batch_repeat > 1) {
191191
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
192192

193193
for (int i = 1; i < batch_repeat; ++i) {

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ void add_slice_tensor_out_node(
4343

4444
dim = normalize(dim, t_in->dim());
4545

46-
// Create a dim value as in the underlying dim is 4-dimension.
47-
int64_t nchw_dim = dim + (4 - t_in->dim());
46+
Dim4DType dim4d = normalize_to_dim4d(*t_in, dim);
4847

4948
std::optional<int64_t> opt_start =
5049
graph.extract_optional_scalar<int64_t>(opt_start_ref);
@@ -61,7 +60,7 @@ void add_slice_tensor_out_node(
6160
VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
6261
VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
6362

64-
if (nchw_dim == 1) {
63+
if (dim4d == DIM4D_CHANNEL) {
6564
// slice by channel
6665
std::string kernel_name = "slice_channel";
6766
kernel_name.reserve(kShaderNameReserve);
@@ -93,17 +92,17 @@ void add_slice_tensor_out_node(
9392
// GPU's coordinate is in x, y, z
9493
int64_t gpu_dim = -1;
9594
int64_t stride = 1;
96-
if (nchw_dim == 3) {
95+
if (dim4d == DIM4D_WIDTH) {
9796
gpu_dim = 0; // width: x dimension in gpu
9897
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
99-
} else if (nchw_dim == 2) {
98+
} else if (dim4d == DIM4D_HEIGHT) {
10099
gpu_dim = 1; // height: y dimension
101100
VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
102-
} else if (nchw_dim == 0) {
101+
} else if (dim4d == DIM4D_BATCH) {
103102
gpu_dim = 2; // batch: z dimension
104103

105104
// Due to channel packing, each batch value is span over stride planes
106-
int64_t n_channels = dim_at<Dim4D::Channel>(in_sizes);
105+
int64_t n_channels = dim_at(in_sizes, DIM4D_CHANNEL);
107106
stride = api::utils::div_up<int64_t>(n_channels, 4ll);
108107
} else {
109108
VK_THROW("Unexpected dim4d!");

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void add_split_with_sizes_default_node(
2929

3030
ValueListPtr out_list = graph.get_value_list(out_list_ref);
3131

32-
NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
32+
Dim4DType dim4d = normalize_to_dim4d(*t_in, dim);
3333

3434
VK_CHECK_COND(out_list->size() == split_sizes.size());
3535

@@ -39,10 +39,10 @@ void add_split_with_sizes_default_node(
3939

4040
vTensorPtr t_out = graph.get_tensor(out_ref);
4141
VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
42-
VK_CHECK_COND(dim_at(*t_out, nchw_dim) == split_size);
42+
VK_CHECK_COND(dim_at(*t_out, dim4d) == split_size);
4343
}
4444

45-
if (nchw_dim == DimWidth) {
45+
if (dim4d == DIM4D_WIDTH) {
4646
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
4747
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
4848

@@ -55,7 +55,7 @@ void add_split_with_sizes_default_node(
5555

5656
src_offset.data[0] += range.data[0];
5757
}
58-
} else if (nchw_dim == DimHeight) {
58+
} else if (dim4d == DIM4D_HEIGHT) {
5959
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
6060
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
6161

@@ -66,7 +66,7 @@ void add_split_with_sizes_default_node(
6666

6767
src_offset.data[1] += range.data[1];
6868
}
69-
} else if (nchw_dim == DimBatch) {
69+
} else if (dim4d == DIM4D_BATCH) {
7070
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
7171
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
7272

@@ -77,13 +77,13 @@ void add_split_with_sizes_default_node(
7777

7878
src_offset.data[2] += range.data[2];
7979
}
80-
} else if (nchw_dim == DimChannel) {
80+
} else if (dim4d == DIM4D_CHANNEL) {
8181
int32_t src_offset = 0;
8282
int32_t dst_offset = 0;
8383

8484
for (ValueRef out_ref : *out_list) {
8585
vTensorPtr t_out = graph.get_tensor(out_ref);
86-
int32_t range = dim_at<Dim4D::Channel>(t_out->sizes());
86+
int32_t range = dim_at<DIM4D_CHANNEL>(t_out->sizes());
8787
add_copy_channel_offset_node(
8888
graph, in, range, src_offset, dst_offset, out_ref);
8989
src_offset += range;
@@ -122,8 +122,8 @@ void add_split_tensor_node(
122122
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
123123

124124
vTensorPtr t_in = graph.get_tensor(in);
125-
NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
126-
int64_t size = dim_at(*t_in, nchw_dim);
125+
Dim4DType dim4d = normalize_to_dim4d(*t_in, dim);
126+
int64_t size = dim_at(*t_in, dim4d);
127127
std::vector<int64_t> split_sizes(size / split_size, split_size);
128128

129129
add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);

0 commit comments

Comments
 (0)