@@ -497,6 +497,59 @@ TEST(Converters, ATenConvTransposeConvertsCorrectly) {
497
497
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
498
498
}
499
499
500
+ TEST (Converters, ATenConvTranspose2dWithWeightsAsTensorsConvertsCorrectly) {
501
+ const auto graph = R"IR(
502
+ graph(%0 : Tensor,
503
+ %1 : Float(48, 56, 3, 3, strides=[504, 9, 3, 1])):
504
+ %2 : int = prim::Constant[value=-128]()
505
+ %3 : float = prim::Constant[value=3.5]()
506
+ %4 : int = prim::Constant[value=0]()
507
+ %5 : int = prim::Constant[value=127]()
508
+ %quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%0, %3, %4, %2, %5)
509
+ %6 : int = prim::Constant[value=6]()
510
+ %7 : int = prim::Constant[value=56]()
511
+ %8 : Device = prim::Constant[value="cuda:0"]()
512
+ %9 : None = prim::Constant()
513
+ %10 : int[] = prim::ListConstruct(%7)
514
+ %11 : Tensor = aten::full(%10, %3, %6, %9, %8, %9)
515
+ %12 : int[] = prim::ListConstruct(%7)
516
+ %13 : int = prim::Constant[value=1]()
517
+ %14 : Tensor = aten::full(%12, %13, %6, %9, %8, %9)
518
+ %quant_wts : Tensor = aten::fake_quantize_per_channel_affine(%1, %11, %14, %13, %2, %5)
519
+ %15 : None = prim::Constant()
520
+ %16 : bool = prim::Constant[value=1]()
521
+ %17 : int = prim::Constant[value=1]() # Adjusted padding
522
+ %17.1: int = prim::Constant[value=0]() # Adjusted out_padding
523
+ %18 : int = prim::Constant[value=1]() # Adjusted dilation
524
+ %19 : int = prim::Constant[value=2]() # Adjusted stride
525
+ %20 : int = prim::Constant[value=1]()
526
+ %21 : int[] = prim::ListConstruct(%17)
527
+ %22 : int[] = prim::ListConstruct(%17, %17)
528
+ %23 : int[] = prim::ListConstruct(%18, %18)
529
+ %23.1: int[] = prim::ListConstruct(%17.1, %17.1)
530
+ %24 : int[] = prim::ListConstruct(%19, %19)
531
+ %25 : Tensor = aten::_convolution(%quant_input, %quant_wts, %15, %24, %22, %23, %16, %23.1, %17, %16, %16, %16, %16)
532
+ return (%25))IR" ;
533
+
534
+ auto g = std::make_shared<torch::jit::Graph>();
535
+ torch::jit::parseIR (graph, g.get ());
536
+
537
+ auto in = at::randint (1 , 10 , {1 , 48 , 2 , 200 }, {at::kCUDA });
538
+ auto w = at::randint (1 , 2 , {48 , 56 , 3 , 3 }, {at::kCUDA });
539
+
540
+ auto jit_in = at::clone (in);
541
+ auto jit_w = at::clone (w);
542
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
543
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_in, jit_w});
544
+
545
+ auto trt_in = at::clone (in);
546
+ auto trt_w = at::clone (w);
547
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
548
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_in, trt_w}, nvinfer1::DataType::kINT8 );
549
+
550
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
551
+ }
552
+
500
553
TEST (Converters, ATenConvTransposeNoBiasConvertsCorrectly) {
501
554
const auto graph = R"IR(
502
555
graph(%0 : Tensor,
0 commit comments