@@ -732,3 +732,69 @@ def forward(self, x):
732
732
sample_inputs ,
733
733
memory_layouts = [vk_graph_schema .VkMemoryLayout .TENSOR_CHANNELS_PACKED ],
734
734
)
735
+
736
+ def test_vulkan_backend_reshape (self ):
737
+ class ReshapeModule (torch .nn .Module ):
738
+ def __init__ (self ):
739
+ super ().__init__ ()
740
+
741
+ def forward (self , x ):
742
+ return torch .reshape (x , [- 1 , x .size (- 1 )])
743
+
744
+ sample_inputs = (torch .randn (size = (5 , 3 , 4 ), dtype = torch .float32 ),)
745
+
746
+ self .lower_module_and_test_output (
747
+ ReshapeModule (),
748
+ sample_inputs ,
749
+ memory_layouts = [vk_graph_schema .VkMemoryLayout .TENSOR_CHANNELS_PACKED ],
750
+ )
751
+
752
+ def test_vulkan_backend_view (self ):
753
+ class ViewModule (torch .nn .Module ):
754
+ def __init__ (self ):
755
+ super ().__init__ ()
756
+
757
+ def forward (self , x ):
758
+ return x .view ([- 1 , x .size (- 1 )])
759
+
760
+ sample_inputs = (torch .randn (size = (3 , 2 , 3 , 4 ), dtype = torch .float32 ),)
761
+
762
+ self .lower_module_and_test_output (
763
+ ViewModule (),
764
+ sample_inputs ,
765
+ memory_layouts = [vk_graph_schema .VkMemoryLayout .TENSOR_CHANNELS_PACKED ],
766
+ )
767
+
768
+ def test_vulkan_backend_unsqueeze (self ):
769
+ class UnsqueezeModule (torch .nn .Module ):
770
+ def __init__ (self ):
771
+ super ().__init__ ()
772
+
773
+ def forward (self , x ):
774
+ x = torch .unsqueeze (x , 1 )
775
+ x = torch .unsqueeze (x , 0 )
776
+ return x
777
+
778
+ sample_inputs = (torch .randn (size = (3 ,), dtype = torch .float32 ),)
779
+
780
+ self .lower_module_and_test_output (
781
+ UnsqueezeModule (),
782
+ sample_inputs ,
783
+ memory_layouts = [vk_graph_schema .VkMemoryLayout .TENSOR_CHANNELS_PACKED ],
784
+ )
785
+
786
+ def test_vulkan_backend_select (self ):
787
+ class SelectModule (torch .nn .Module ):
788
+ def __init__ (self ):
789
+ super ().__init__ ()
790
+
791
+ def forward (self , x ):
792
+ return x [0 ][3 ]
793
+
794
+ sample_inputs = (torch .randn (size = (3 , 6 , 2 , 7 ), dtype = torch .float32 ),)
795
+
796
+ self .lower_module_and_test_output (
797
+ SelectModule (),
798
+ sample_inputs ,
799
+ memory_layouts = [vk_graph_schema .VkMemoryLayout .TENSOR_CHANNELS_PACKED ],
800
+ )
0 commit comments