Add op: enable transposed convolution #4197

Closed · wants to merge 1 commit into from
198 changes: 140 additions & 58 deletions kernels/portable/cpu/op_convolution.cpp
@@ -50,8 +50,10 @@ void conv2d_impl(
StridesArrayRef out_strides,
const size_t batch,
const size_t group,
const size_t out_c) {
const size_t out_c,
bool transposed) {
size_t in_C = in_sizes[1];
size_t out_C = out_sizes[1];

size_t out_H = out_sizes[2];
size_t in_H = in_sizes[2];
@@ -64,13 +66,15 @@ void conv2d_impl(
size_t in_C_per_group = in_C / groups;
size_t in_c_start = group * in_C_per_group;

size_t out_C_per_group = out_C / groups;
size_t out_c_start = group * out_C_per_group;

exec_aten::SizesType in_coord[kTensorDimensionLimit];
in_coord[0] = batch;
exec_aten::SizesType out_coord[kTensorDimensionLimit];
out_coord[0] = batch;
out_coord[1] = out_c;
exec_aten::SizesType w_coord[kTensorDimensionLimit];
w_coord[0] = out_c;

const int64_t stride_y = val_at(stride, 0);
const int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
@@ -79,53 +83,115 @@ void conv2d_impl(
const int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
const int64_t dilation_x = val_at(dilation, 1);

// Compute 2D output region
for (size_t out_y = 0; out_y < out_H; ++out_y) {
out_coord[2] = out_y;
for (size_t out_x = 0; out_x < out_W; ++out_x) {
out_coord[3] = out_x;

CTYPE accum = 0.0f;
for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
++in_c) {
in_coord[1] = in_c;
w_coord[1] = in_c - in_c_start;

for (size_t w_y = 0; w_y < w_H; ++w_y) {
w_coord[2] = w_y;

size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
in_coord[2] = in_y;
// Only proceed if input y coordinate is within bounds
if (in_y >= 0 && in_y < in_H) {
for (size_t w_x = 0; w_x < w_W; ++w_x) {
w_coord[3] = w_x;

size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
in_coord[3] = in_x;

// Only proceed if input coordinate is within bounds
if (in_x >= 0 && in_x < in_W) {
size_t in_idx =
calculate_linear_index(in_coord, in_strides.data(), 4);
CTYPE in_val = in_ptr[in_idx];

size_t w_idx =
calculate_linear_index(w_coord, w_strides.data(), 4);
CTYPE w_val = w_ptr[w_idx];

accum += in_val * w_val;
if (!transposed) {
w_coord[0] = out_c;
// Compute 2D output region
for (size_t out_y = 0; out_y < out_H; ++out_y) {
out_coord[2] = out_y;
for (size_t out_x = 0; out_x < out_W; ++out_x) {
out_coord[3] = out_x;

CTYPE accum = 0.0f;
for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
++in_c) {
in_coord[1] = in_c;
w_coord[1] = in_c - in_c_start;

for (size_t w_y = 0; w_y < w_H; ++w_y) {
w_coord[2] = w_y;

size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
in_coord[2] = in_y;
// Only proceed if input y coordinate is within bounds
if (in_y >= 0 && in_y < in_H) {
for (size_t w_x = 0; w_x < w_W; ++w_x) {
w_coord[3] = w_x;

size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
in_coord[3] = in_x;

// Only proceed if input x coordinate is within bounds
if (in_x >= 0 && in_x < in_W) {
size_t in_idx =
calculate_linear_index(in_coord, in_strides.data(), 4);
CTYPE in_val = in_ptr[in_idx];

size_t w_idx =
calculate_linear_index(w_coord, w_strides.data(), 4);
CTYPE w_val = w_ptr[w_idx];

accum += in_val * w_val;
}
}
}
}
}

if (bias_ptr != nullptr) {
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
}
size_t out_idx =
calculate_linear_index(out_coord, out_strides.data(), 4);
out_ptr[out_idx] = accum;
}
}
} else { // transposed convolution
if (bias_ptr != nullptr) {
out_coord[2] = 0;
out_coord[3] = 0;
size_t out_c_start_idx =
calculate_linear_index(out_coord, out_strides.data(), 4);
size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
}
}

w_coord[1] = out_c - out_c_start;

for (size_t in_y = 0; in_y < in_H; ++in_y) {
in_coord[2] = in_y;

for (size_t in_x = 0; in_x < in_W; ++in_x) {
in_coord[3] = in_x;

for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
++in_c) {
in_coord[1] = in_c;

size_t in_idx =
calculate_linear_index(in_coord, in_strides.data(), 4);
CTYPE in_val = in_ptr[in_idx];

w_coord[0] = in_c;
for (size_t w_y = 0; w_y < w_H; ++w_y) {
w_coord[2] = w_y;
size_t out_y = stride_y * in_y + dilation_y * w_y - padding_y;
out_coord[2] = out_y;

// Only proceed if output y coordinate is within bounds
if (out_y >= 0 && out_y < out_H) {
for (size_t w_x = 0; w_x < w_W; ++w_x) {
w_coord[3] = w_x;
size_t out_x = stride_x * in_x + dilation_x * w_x - padding_x;
out_coord[3] = out_x;

if (bias_ptr != nullptr) {
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
// Only proceed if output x coordinate is within bounds
if (out_x >= 0 && out_x < out_W) {
size_t w_idx =
calculate_linear_index(w_coord, w_strides.data(), 4);
CTYPE w_val = w_ptr[w_idx];

size_t out_idx =
calculate_linear_index(out_coord, out_strides.data(), 4);

out_ptr[out_idx] += in_val * w_val;
}
}
}
}
}
}
size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
out_ptr[out_idx] = accum;
}
}
}
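
The new `else` branch above implements the transposed case as a scatter: instead of gathering input values for each output position, it walks the input and accumulates `in_val * w_val` into every output location the kernel window touches (`out = stride * in + dilation * w - padding`). For readers less familiar with that formulation, here is a minimal, self-contained 1-D sketch of the same idea. It is not part of the patch; it uses no ExecuTorch types, and the function name, parameters, and dense layout are purely illustrative.

#include <cstdio>
#include <vector>

std::vector<float> conv_transpose1d(
    const std::vector<float>& in,
    const std::vector<float>& w,
    int stride,
    int padding,
    int dilation) {
  const int in_W = static_cast<int>(in.size());
  const int w_W = static_cast<int>(w.size());
  // Standard transposed-convolution output size (output_padding omitted).
  const int out_W =
      (in_W - 1) * stride - 2 * padding + dilation * (w_W - 1) + 1;
  // Accumulating with += requires a zero-initialized output, mirroring the
  // bias fill / memset in the patch.
  std::vector<float> out(out_W, 0.0f);
  for (int in_x = 0; in_x < in_W; ++in_x) {
    for (int w_x = 0; w_x < w_W; ++w_x) {
      // Same coordinate mapping as the 2-D kernel above.
      const int out_x = stride * in_x + dilation * w_x - padding;
      if (out_x >= 0 && out_x < out_W) {
        out[out_x] += in[in_x] * w[w_x];
      }
    }
  }
  return out;
}

int main() {
  // [1, 2, 3] with kernel [1, 1] and stride 2 scatters to [1, 1, 2, 2, 3, 3].
  for (float v : conv_transpose1d({1.0f, 2.0f, 3.0f}, {1.0f, 1.0f},
                                  /*stride=*/2, /*padding=*/0,
                                  /*dilation=*/1)) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}
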
@@ -138,14 +204,9 @@ void convolution_wrapper(
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool transposed,
int64_t groups,
Tensor& out) {
size_t out_N = in.size(0);
size_t out_C = weight.size(0);

// Compute the number of in and out channels in each group
size_t out_C_per_group = out_C / groups;

SizesArrayRef in_sizes = in.sizes();
SizesArrayRef weight_sizes = weight.sizes();
SizesArrayRef out_sizes = out.sizes();
@@ -233,6 +294,15 @@ void convolution_wrapper(
const CTYPE_BIAS* const bias_ptr =
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;

size_t out_N = out.size(0);
size_t out_C_per_group = out.size(1) / groups;

if (transposed && bias_ptr == nullptr) {
// If bias is not present, we need to initialize the output to 0
// before we can accumulate into it.
memset(out_ptr, 0, out.nbytes());
}

for (size_t batch = 0; batch < out_N; ++batch) {
for (size_t group = 0; group < groups; ++group) {
// Align channel offset based on the group
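
Note on the zero-initialization added above: the non-transposed path writes each output element exactly once (`out_ptr[out_idx] = accum`), while the transposed path accumulates with `+=`. The output buffer therefore has to start from a defined value before `conv2d_impl` runs: the per-channel bias broadcast done inside `conv2d_impl` when a bias is present, or all zeros via `memset` here when it is not.
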
@@ -257,7 +327,8 @@ void convolution_wrapper(
{out_strides, 4},
batch,
group,
out_c);
out_c,
transposed);
}
}
}
@@ -273,8 +344,8 @@ Tensor& convolution_out(
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
__ET_UNUSED bool transposed,
__ET_UNUSED IntArrayRef output_padding,
bool transposed,
IntArrayRef output_padding,
int64_t groups,
Tensor& out) {
(void)ctx;
@@ -298,7 +369,16 @@ Tensor& convolution_out(
size_t output_ndim = 0;
exec_aten::SizesType output_sizes[kTensorDimensionLimit];
get_convolution_out_target_size(
in, weight, stride, padding, dilation, output_sizes, &output_ndim);
in,
weight,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_sizes,
&output_ndim);

ET_KERNEL_CHECK(
ctx,
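
For reference, the two paths imply different output sizes, which is why `get_convolution_out_target_size` now receives `transposed`, `output_padding`, and `groups`. The exact computation lives in that helper (not shown in this diff); the standard sizing formulas, as documented for PyTorch's Conv2d / ConvTranspose2d (the width dimension is analogous), are:

\[
H_{out}^{\text{conv}} = \left\lfloor \frac{H_{in} + 2\,\text{padding} - \text{dilation}\,(\text{kernel\_size} - 1) - 1}{\text{stride}} \right\rfloor + 1
\]
\[
H_{out}^{\text{transposed}} = (H_{in} - 1)\,\text{stride} - 2\,\text{padding} + \text{dilation}\,(\text{kernel\_size} - 1) + \text{output\_padding} + 1
\]

Here `output_padding` only enlarges the transposed output on one side to disambiguate the target size; it does not change which input elements are read.
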
@@ -321,12 +401,14 @@ Tensor& convolution_out(
if (bias.has_value()) {
bias_type = bias.value().scalar_type();
}
ET_SWITCH_REAL_TYPES(in_type, ctx, "convolution.out", CTYPE, [&]() {
ET_SWITCH_REAL_TYPES_AND(
Bool, bias_type, ctx, "convolution.out", CTYPE_BIAS, [&]() {
convolution_wrapper<CTYPE, CTYPE_BIAS>(
in, weight, bias, stride, padding, dilation, groups, out);
});

constexpr auto name = "convolution.out";

ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
convolution_wrapper<CTYPE, CTYPE_BIAS>(
in, weight, bias, stride, padding, dilation, transposed, groups, out);
});
});

return out;
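
One more behavioral note, assuming the switch macros follow the library's usual naming (`H` for Half, `B` for Bool): replacing `ET_SWITCH_REAL_TYPES` / `ET_SWITCH_REAL_TYPES_AND(Bool, ...)` with `ET_SWITCH_REALH_TYPES` / `ET_SWITCH_REALHB_TYPES` also extends dtype coverage to Half for the input, weight, and output types, while keeping Bool support for the bias.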