Skip to content

Commit 101fac6

Browse files
authored
fix: Fix deconv kernel channel num_output_maps where wts are ITensor (#2678)
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 50206d5 commit 101fac6

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,12 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
139139
filter_dim.nbDims = nbSpatialDims;
140140
filter_dim.d[0] = kernel_dims.d[2];
141141
filter_dim.d[1] = kernel_dims.d[3];
142+
// For Conv2d layer, weights are in the shape of (out_channels, in_channels/groups,...)
142143
int32_t num_output_maps = kernel_dims.d[0];
144+
if (transposed) {
145+
// For ConvTranspose layer, weights are in the shape of (in_channels, out_channels/groups,...)
146+
num_output_maps = kernel_dims.d[1];
147+
}
143148
bool expand_dims = nbSpatialDims == 1;
144149
if (expand_dims) {
145150
// In case of Conv1D -> map it to 2D version
@@ -150,9 +155,6 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
150155
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
151156
kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle"));
152157
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
153-
if (transposed) {
154-
num_output_maps = kernel_dims.d[1];
155-
}
156158
}
157159

158160
// Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.

tests/core/conversion/converters/test_conv_deconv.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,59 @@ TEST(Converters, ATenConvTransposeConvertsCorrectly) {
497497
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
498498
}
499499

500+
TEST(Converters, ATenConvTranspose2dWithWeightsAsTensorsConvertsCorrectly) {
501+
const auto graph = R"IR(
502+
graph(%0 : Tensor,
503+
%1 : Float(48, 56, 3, 3, strides=[504, 9, 3, 1])):
504+
%2 : int = prim::Constant[value=-128]()
505+
%3 : float = prim::Constant[value=3.5]()
506+
%4 : int = prim::Constant[value=0]()
507+
%5 : int = prim::Constant[value=127]()
508+
%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
509+
%6 : int = prim::Constant[value=6]()
510+
%7 : int = prim::Constant[value=56]()
511+
%8 : Device = prim::Constant[value="cuda:0"]()
512+
%9 : None = prim::Constant()
513+
%10 : int[] = prim::ListConstruct(%7)
514+
%11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
515+
%12 : int[] = prim::ListConstruct(%7)
516+
%13 : int = prim::Constant[value=1]()
517+
%14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
518+
%quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
519+
%15 : None = prim::Constant()
520+
%16 : bool = prim::Constant[value=1]()
521+
%17 : int = prim::Constant[value=1]() # Adjusted padding
522+
%17.1: int = prim::Constant[value=0]() # Adjusted out_padding
523+
%18 : int = prim::Constant[value=1]() # Adjusted dilation
524+
%19 : int = prim::Constant[value=2]() # Adjusted stride
525+
%20 : int = prim::Constant[value=1]()
526+
%21 : int[] = prim::ListConstruct(%17)
527+
%22 : int[] = prim::ListConstruct(%17, %17)
528+
%23 : int[] = prim::ListConstruct(%18, %18)
529+
%23.1: int[] = prim::ListConstruct(%17.1, %17.1)
530+
%24 : int[] = prim::ListConstruct(%19, %19)
531+
%25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %24, %22, %23, %16, %23.1, %17, %16, %16, %16, %16)
532+
return (%25))IR";
533+
534+
auto g = std::make_shared<torch::jit::Graph>();
535+
torch::jit::parseIR(graph, g.get());
536+
537+
auto in = at::randint(1, 10, {1, 48, 2, 200}, {at::kCUDA});
538+
auto w = at::randint(1, 2, {48, 56, 3, 3}, {at::kCUDA});
539+
540+
auto jit_in = at::clone(in);
541+
auto jit_w = at::clone(w);
542+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
543+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in, jit_w});
544+
545+
auto trt_in = at::clone(in);
546+
auto trt_w = at::clone(w);
547+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
548+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in, trt_w}, nvinfer1::DataType::kINT8);
549+
550+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
551+
}
552+
500553
TEST(Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
501554
const auto graph = R"IR(
502555
graph(%0 : Tensor,

0 commit comments

Comments
 (0)