@@ -3,6 +3,7 @@
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/runtime/jit_exception.h"
 #include "torch/torch.h"
 
 TEST(Evaluators, DivIntEvaluatesCorrectly) {
@@ -579,4 +580,89 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
   auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
 
   ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
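+// Verifies that the aten::format evaluator expands "res{}_{}_" with arguments
+// 5 and 2 into "res5_2_", so the aten::eq check is true and both the
+// TorchScript and TensorRT paths take the aten::add branch of the prim::If.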
+TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x_1 : Tensor, %x_2 : Tensor):
+        %0 : int = prim::Constant[value=1]()
+        %1 : str = prim::Constant[value="res{}_{}_"]()
+        %2 : int = prim::Constant[value=5]()
+        %2.1 : int = prim::Constant[value=2]()
+        %3 : str = prim::Constant[value="res5_2_"]()
+        %4 : str = aten::format(%1, %2, %2.1)
+        %5 : bool = aten::eq(%3, %4)
+        %y : Tensor = prim::If(%5)
+          block0():
+            %194 : Tensor = aten::add(%x_1, %x_2, %0)
+            -> (%194)
+          block1():
+            %195 : Tensor = aten::sub(%x_1, %x_2, %0)
+            -> (%195)
+        return (%y))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = in0.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
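+// Verifies that a prim::RaiseException whose message is built by aten::format
+// surfaces the same message through Torch-TensorRT as through TorchScript once
+// each runtime's error prefix is stripped.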
+TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph(%x_1 : Tensor, %x_2 : Tensor):
+        %0 : int = prim::Constant[value=1]()
+        %1 : str = prim::Constant[value="res5_1"]()
+        %2 : str = prim::Constant[value="{} is not equal to {}"]()
+        %3 : str = prim::Constant[value="res5_2"]()
+        %5713 : Tensor = prim::Uninitialized()
+        %4 : str = aten::format(%2, %1, %3)
+        %5 : bool = aten::eq(%1, %3)
+        %y : Tensor = prim::If(%5)
+          block0():
+            %194 : Tensor = aten::add(%x_1, %x_2, %0)
+            -> (%194)
+          block1():
+            prim::RaiseException(%4)
+            -> (%5713)
+        return (%y))IR";
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
+  auto in1 = in0.clone();
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  std::vector<at::Tensor> jit_results, trt_results;
+  std::string error_jit, error_torch_trt;
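+  // %1 ("res5_1") != %3 ("res5_2"), so both runs take block1 and throw;
+  // capture each runtime's message for comparison below.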
+  try {
+    jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
+  } catch (const torch::jit::JITException& error) {
+    error_jit = error.what();
+  }
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  try {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
+  } catch (const torch_tensorrt::Error& error) {
+    error_torch_trt = error.what();
+  }
+
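+  // Strip each runtime's prefix before comparing: "RuntimeError:" is 13
+  // characters and "Error from TorchScript:" is 23.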
+  auto position1 = error_jit.find("RuntimeError:");
+  auto position2 = error_torch_trt.find("Error from TorchScript:");
+  std::string jit_msg = error_jit.substr(position1 + 13);
+  std::string torch_trt_msg = error_torch_trt.substr(position2 + 23);
+  ASSERT_TRUE(jit_msg == torch_trt_msg);
 }