@@ -704,7 +704,30 @@ def get_clone_inputs():
704
704
705
705
@register_test_suite ("aten.repeat.default" )
706
706
def get_repeat_inputs ():
707
- test_suite = VkTestSuite (
707
+ test_suite_2d = VkTestSuite (
708
+ [
709
+ # Repeat channels only (most challenging case)
710
+ ((3 , XS , S ), [2 , 1 , 1 ]),
711
+ ((7 , XS , S ), [4 , 1 , 1 ]),
712
+ # More other cases
713
+ ((2 , 3 ), [1 , 4 ]),
714
+ ((2 , 3 ), [4 , 1 ]),
715
+ ((2 , 3 ), [4 , 4 ]),
716
+ ((S1 , S2 , S2 ), [1 , 3 , 1 ]),
717
+ ((S1 , S2 , S2 ), [1 , 3 , 3 ]),
718
+ ((S1 , S2 , S2 ), [3 , 3 , 1 ]),
719
+ ((S1 , S2 , S2 ), [3 , 3 , 3 ]),
720
+ # Expanding cases
721
+ ((2 , 3 ), [3 , 1 , 4 ]),
722
+ ]
723
+ )
724
+ test_suite_2d .layouts = ["utils::kChannelsPacked" ]
725
+ test_suite_2d .storage_types = ["utils::kTexture2D" ]
726
+ test_suite_2d .data_gen = "make_seq_tensor"
727
+ test_suite_2d .dtypes = ["at::kFloat" ]
728
+ test_suite_2d .test_name_suffix = "2d"
729
+
730
+ test_suite_3d = VkTestSuite (
708
731
[
709
732
# Repeat channels only (most challenging case)
710
733
((3 , XS , S ), [2 , 1 , 1 ]),
@@ -739,13 +762,13 @@ def get_repeat_inputs():
739
762
((2 , 3 ), [3 , 3 , 2 , 4 ]),
740
763
]
741
764
)
742
- test_suite .layouts = [
743
- "utils::kChannelsPacked" ,
744
- ]
745
- test_suite . storage_types = ["utils::kTexture2D" , "utils::kTexture3D " ]
746
- test_suite . data_gen = "make_seq_tensor "
747
- test_suite . dtypes = [ "at::kFloat" ]
748
- return test_suite
765
+ test_suite_3d .layouts = ["utils::kChannelsPacked" ]
766
+ test_suite_3d . storage_types = [ "utils::kTexture3D" ]
767
+ test_suite_3d . data_gen = "make_seq_tensor"
768
+ test_suite_3d . dtypes = ["at::kFloat " ]
769
+ test_suite_2d . test_name_suffix = "3d "
770
+
771
+ return [ test_suite_2d , test_suite_3d ]
749
772
750
773
751
774
@register_test_suite ("aten.repeat_interleave.self_int" )
@@ -1164,7 +1187,7 @@ def get_squeeze_copy_dim_inputs():
1164
1187
1165
1188
@register_test_suite ("aten.flip.default" )
1166
1189
def get_flip_inputs ():
1167
- Test = namedtuple ("VkIndexSelectTest " , ["self" , "dim" ])
1190
+ Test = namedtuple ("Flip " , ["self" , "dim" ])
1168
1191
Test .__new__ .__defaults__ = (None , 0 )
1169
1192
1170
1193
test_cases = [
0 commit comments