Skip to content

Commit 25a94ef

Browse files
authored
Qualcomm AI Engine Direct - Add unit test for Spill-Fill buffer (#7518)
Add unit test to validate the size of the Spill-Fill buffer.
1 parent 84e377a commit 25a94ef

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,18 @@ def forward(self, input_pos, k_val):
596596
return k_out
597597

598598

599+
class LargeTensorLinear(torch.nn.Module):
    """Two-layer MLP with a deliberately large (4096) hidden dimension.

    Used by the QNN backend tests to produce intermediate activations big
    enough to require a non-zero spill-fill buffer when compiled with
    multi-context HTP options.
    """

    def __init__(self):
        super().__init__()
        hidden_dim = 4096
        # 512 -> 4096 -> 512; the wide hidden layer creates the large tensors.
        self.linear1 = torch.nn.Linear(512, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, 512)

    def forward(self, x):
        # linear1 is invoked twice on purpose: the captured graph then
        # contains two large producer ops feeding the add, not one.
        doubled = self.linear1(x) + self.linear1(x)
        return self.linear2(doubled)
609+
610+
599611
class LayerNorm(torch.nn.Module):
600612
def __init__(self):
601613
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self):
15811581
skip_node_op_set={"aten.add.Tensor"},
15821582
)
15831583

1584+
def test_qnn_backend_spill_fill_buffer_size(self):
    """Validate that compiling a large fp16 model with multi-contexts
    enabled yields a non-zero spill-fill buffer size.

    NOTE(review): this only checks the size is non-zero, not that it is
    correct for the model — presumably exact sizing is backend-dependent.
    """
    module = LargeTensorLinear()  # noqa: F405
    # (1, 256, 512) input drives the 512->4096->512 MLP, producing
    # intermediates large enough to need spill-fill.
    sample_input = (torch.randn(1, 256, 512),)
    edge_prog = capture_program(module, sample_input)

    # use_multi_contexts is required: spill-fill sharing applies across
    # HTP contexts.
    backend_options = generate_htp_compiler_spec(
        use_fp16=True,
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=self.chipset_table[TestQNN.model],
        backend_options=backend_options,
    )
    partitioner = QnnPartitioner(compiler_specs)
    # Lower the partitioned graph, then query the spill-fill size the
    # delegated program reports.
    edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
    max_sf_size = update_spill_fill_size(edge_prog.exported_program)
    self.assertNotEqual(0, max_sf_size)
1601+
15841602
def test_qnn_backend_multi_contexts(self):
15851603
module = SimpleModel() # noqa: F405
15861604
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
@@ -2011,6 +2029,25 @@ def calibrator(gm):
20112029
).to_executorch()
20122030
self.verify_output(module, sample_input, exec_prog)
20132031

2032+
def test_qnn_backend_spill_fill_buffer_size(self):
    """Quantized counterpart of the fp16 spill-fill test: the model is
    first converted to a QDQ (quantize-dequantize) module, then lowered
    with use_fp16=False, and the resulting spill-fill buffer size must be
    non-zero.
    """
    module = LargeTensorLinear()  # noqa: F405
    sample_input = (torch.randn(1, 256, 512),)
    # Quantize before capture so the lowered graph exercises the int
    # (non-fp16) HTP path.
    module = self.get_qdq_module(module, sample_input)
    edge_prog = capture_program(module, sample_input)

    # use_multi_contexts is required: spill-fill sharing applies across
    # HTP contexts.
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=self.chipset_table[TestQNN.model],
        backend_options=backend_options,
    )
    partitioner = QnnPartitioner(compiler_specs)
    # Lower the partitioned graph, then query the spill-fill size the
    # delegated program reports.
    edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
    max_sf_size = update_spill_fill_size(edge_prog.exported_program)
    self.assertNotEqual(0, max_sf_size)
2050+
20142051
def test_qnn_backend_graph_level_mixed_precision(self):
20152052
module = SimpleModel() # noqa: F405
20162053
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))

backends/qualcomm/utils/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,17 @@ def set_spec(module, options):
269269
options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
270270
set_spec(module, options)
271271

272+
max_sf_size, modules_map = 0, {}
272273
if isinstance(exported_program, list):
273-
max_sf_size, modules_map = 0, {}
274274
for prog in exported_program:
275275
max_sf_buf_size, module_map = get_program_info(prog)
276276
max_sf_size = max(max_sf_size, max_sf_buf_size)
277277
modules_map.update(module_map)
278-
update_program(max_sf_size, modules_map)
279278
else:
280-
update_program(*get_program_info(exported_program))
279+
max_sf_size, module_map = get_program_info(exported_program)
280+
update_program(max_sf_size, module_map)
281+
282+
return max_sf_size
281283

282284

283285
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:

0 commit comments

Comments
 (0)