8
8
import unittest
9
9
from typing import Tuple
10
10
11
- import executorch .exir as exir
12
11
import torch
13
12
14
- # import the vulkan backend implementation
13
+ from executorch . backends . vulkan . partitioner . vulkan_partitioner import VulkanPartitioner
15
14
from executorch .backends .vulkan .vulkan_preprocess import VulkanBackend
16
15
17
- from executorch .exir import ExecutorchProgram
18
- from executorch . exir . backend . backend_api import to_backend
16
+ from executorch .exir import EdgeProgramManager , to_edge
17
+ from torch . export import export , ExportedProgram
19
18
20
19
ctypes .CDLL ("libvulkan.so.1" )
21
20
@@ -51,7 +50,7 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)
51
50
52
51
def lower_module_and_test_output (
53
52
self ,
54
- module : torch .nn .Module ,
53
+ model : torch .nn .Module ,
55
54
sample_inputs : Tuple [torch .Tensor ],
56
55
atol = 1e-03 ,
57
56
rtol = 1e-01 ,
@@ -61,36 +60,23 @@ def lower_module_and_test_output(
61
60
the given sample inputs. It then runs the lowered module and compares its
62
61
outputs with the outputs of the eager module.
63
62
"""
64
- edgeir_m = exir .capture (module , sample_inputs , exir .CaptureConfig ()).to_edge ()
65
- lowered_module = to_backend ("VulkanBackend" , edgeir_m .exported_program , [])
63
+ program : ExportedProgram = export (model , sample_inputs )
64
+ edge_program : EdgeProgramManager = to_edge (program )
65
+ edge_program = edge_program .to_backend (VulkanPartitioner ())
66
66
67
- class WrappedModule (torch .nn .Module ):
68
- def __init__ (self ):
69
- super ().__init__ ()
70
- self .one_module = lowered_module
71
-
72
- def forward (self , * args ):
73
- return self .one_module (* args )
67
+ executorch_program = edge_program .to_executorch ()
74
68
75
- executorch_program : ExecutorchProgram = (
76
- exir .capture (WrappedModule (), sample_inputs , exir .CaptureConfig ())
77
- .to_edge ()
78
- .to_executorch ()
79
- )
80
-
81
- # Assert the backend name is vulkan
82
69
self .assertEqual (
83
- executorch_program .program .execution_plan [0 ].delegates [0 ].id ,
70
+ executorch_program .executorch_program .execution_plan [0 ].delegates [0 ].id ,
84
71
VulkanBackend .__name__ ,
85
72
)
86
73
87
- # Test the model with executor
88
74
executorch_module = _load_for_executorch_from_buffer (executorch_program .buffer )
89
75
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
90
76
inputs_flattened , _ = tree_flatten (sample_inputs )
91
77
92
78
model_output = executorch_module .run_method ("forward" , tuple (inputs_flattened ))
93
- ref_output = module (* sample_inputs )
79
+ ref_output = model (* sample_inputs )
94
80
95
81
self .assert_outputs_equal (model_output , ref_output , atol = atol , rtol = rtol )
96
82
@@ -192,26 +178,6 @@ def forward(self, x, y):
192
178
193
179
self .lower_module_and_test_output (div_module , model_inputs )
194
180
195
- def test_vulkan_backend_floor_div (self ):
196
- class FloorDivModule (torch .nn .Module ):
197
- def __init__ (self ):
198
- super ().__init__ ()
199
-
200
- def forward (self , x , y ):
201
- z = x // y
202
- return z
203
-
204
- floor_div_module = FloorDivModule ()
205
- model_inputs = (
206
- torch .rand (size = (2 , 3 ), dtype = torch .float32 ) * 10.0 ,
207
- torch .rand (size = (2 , 3 ), dtype = torch .float32 ) + 1.0 ,
208
- )
209
-
210
- # absolute tolerance is 1 because of flooring
211
- self .lower_module_and_test_output (
212
- floor_div_module , model_inputs , atol = 1.0 + 1e-03
213
- )
214
-
215
181
def test_vulkan_backend_arithmetic (self ):
216
182
class ArithmeticModule (torch .nn .Module ):
217
183
def __init__ (self ):
@@ -249,3 +215,23 @@ def forward(self, x, y):
249
215
)
250
216
251
217
self .lower_module_and_test_output (pow_module , model_inputs )
218
+
219
+ def test_vulkan_backend_partial (self ):
220
+ class SimpleModel (torch .nn .Module ):
221
+ def __init__ (self ):
222
+ super ().__init__ ()
223
+ self .linear = torch .nn .Linear (10 , 10 )
224
+ self .offset_1 = self .weight = torch .rand (
225
+ size = (2 , 10 ), dtype = torch .float32
226
+ )
227
+ self .offset_2 = self .weight = torch .rand (
228
+ size = (2 , 10 ), dtype = torch .float32
229
+ )
230
+
231
+ def forward (self , x ):
232
+ return self .linear (x + self .offset_1 ) - self .offset_2
233
+
234
+ model = SimpleModel ()
235
+ model_inputs = (torch .rand (size = (2 , 10 ), dtype = torch .float32 ),)
236
+
237
+ self .lower_module_and_test_output (model , model_inputs )
0 commit comments