@@ -612,36 +612,35 @@ def gen_op_check_fn(self) -> str:
612
612
613
613
test_fixture_template = """
614
614
class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
615
- protected:
616
- ComputeGraph* graph;
617
- at::ScalarType test_dtype = at::kFloat;
618
- float rtol = {rtol};
619
- float atol = {atol};
620
-
621
- void SetUp() override {{
622
- GraphConfig config;
623
- api::StorageType default_storage_type;
624
- api::GPUMemoryLayout default_memory_layout;
625
- std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
626
- config.set_storage_type_override(default_storage_type);
627
- config.set_memory_layout_override(default_memory_layout);
628
- graph = new ComputeGraph(config);
629
-
630
- if (test_dtype == at::kHalf) {{
631
- rtol = 1e-2;
632
- atol = 1e-2;
633
- }}
615
+ protected:
616
+ ComputeGraph* graph;
617
+ at::ScalarType test_dtype = at::kFloat;
618
+ float rtol = {rtol};
619
+ float atol = {atol};
620
+
621
+ void SetUp() override {{
622
+ GraphConfig config;
623
+ api::StorageType default_storage_type;
624
+ api::GPUMemoryLayout default_memory_layout;
625
+ std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
626
+ config.set_storage_type_override(default_storage_type);
627
+ config.set_memory_layout_override(default_memory_layout);
628
+ graph = new ComputeGraph(config);
629
+
630
+ if (test_dtype == at::kHalf) {{
631
+ rtol = 1e-2;
632
+ atol = 1e-2;
634
633
}}
634
+ }}
635
635
636
- void TearDown() override {{
637
- delete graph;
638
- graph = nullptr;
639
- }}
640
-
641
- {check_fn}
636
+ void TearDown() override {{
637
+ delete graph;
638
+ graph = nullptr;
639
+ }}
642
640
643
- {prepacked_check_fn }
641
+ {check_fn }
644
642
643
+ {prepacked_check_fn}
645
644
}};
646
645
"""
647
646
@@ -676,13 +675,13 @@ def gen_parameterization(self) -> str:
676
675
layouts = self .suite_def .layouts
677
676
678
677
return f"""
679
- INSTANTIATE_TEST_SUITE_P(
680
- Combos_{ self .op_name } ,
681
- GeneratedOpsTest_{ self .op_name } ,
682
- ::testing::Combine(
683
- ::testing::Values({ ', ' .join (dtypes )} ),
684
- ::testing::Values({ ', ' .join (storage_types )} ),
685
- ::testing::Values({ ', ' .join (layouts )} )));
678
+ INSTANTIATE_TEST_SUITE_P(
679
+ Combos_{ self .op_name } ,
680
+ GeneratedOpsTest_{ self .op_name } ,
681
+ ::testing::Combine(
682
+ ::testing::Values({ ', ' .join (dtypes )} ),
683
+ ::testing::Values({ ', ' .join (storage_types )} ),
684
+ ::testing::Values({ ', ' .join (layouts )} )));
686
685
"""
687
686
688
687
@@ -701,41 +700,41 @@ def gen_parameterization(self) -> str:
701
700
using TensorOptions = at::TensorOptions;
702
701
703
702
api::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
704
- switch(at_scalartype) {
705
- case c10::kFloat:
706
- return api::kFloat;
707
- case c10::kHalf:
708
- return api::kHalf;
709
- case c10::kInt:
710
- return api::kInt;
711
- case c10::kLong:
712
- return api::kInt;
713
- case c10::kChar:
714
- return api::kChar;
715
- default:
716
- VK_THROW("Unsupported at::ScalarType!");
717
- }
703
+ switch (at_scalartype) {
704
+ case c10::kFloat:
705
+ return api::kFloat;
706
+ case c10::kHalf:
707
+ return api::kHalf;
708
+ case c10::kInt:
709
+ return api::kInt;
710
+ case c10::kLong:
711
+ return api::kInt;
712
+ case c10::kChar:
713
+ return api::kChar;
714
+ default:
715
+ VK_THROW("Unsupported at::ScalarType!");
716
+ }
718
717
}
719
718
720
719
#ifdef USE_VULKAN_FP16_INFERENCE
721
720
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) {
722
721
#else
723
722
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) {
724
723
#endif
725
- // Skip checking index tensors
726
- if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
727
- return true;
728
- }
729
- bool is_close = at::allclose(t1, t2, rtol, atol);
730
- if (!is_close && t1.numel() < 500) {
731
- std::cout << "reference: " << std::endl;
732
- print(t1, 150);
733
- std::cout << std::endl;
734
- std::cout << "vulkan: " << std::endl;
735
- print(t2, 150);
736
- std::cout << std::endl;
737
- }
738
- return is_close;
724
+ // Skip checking index tensors
725
+ if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
726
+ return true;
727
+ }
728
+ bool is_close = at::allclose(t1, t2, rtol, atol);
729
+ if (!is_close && t1.numel() < 500) {
730
+ std::cout << "reference: " << std::endl;
731
+ print(t1, 150);
732
+ std::cout << std::endl;
733
+ std::cout << "vulkan: " << std::endl;
734
+ print(t2, 150);
735
+ std::cout << std::endl;
736
+ }
737
+ return is_close;
739
738
}
740
739
"""
741
740
0 commit comments