Skip to content

Commit e6648e9

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Add tests for zero-dim tensors
Summary: Turns out zero dim tensors don't need anything special to be enabled. Therefore just add test cases for them. Differential Revision: D57463151
1 parent 5c70121 commit e6648e9

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,19 @@ def forward(self, x, y, w):
202202

203203
self.lower_module_and_test_output(add_module, sample_inputs)
204204

205+
def test_vulkan_backend_zero_dim_tensor(self):
206+
class ZeroDimModule(torch.nn.Module):
207+
def __init__(self):
208+
super().__init__()
209+
self.zero = torch.full([], 1.3, dtype=torch.float32)
210+
211+
def forward(self, x):
212+
return x + self.zero
213+
214+
internal_data_module = ZeroDimModule()
215+
sample_inputs = (torch.rand(size=(2, 3), dtype=torch.float32),)
216+
self.lower_module_and_test_output(internal_data_module, sample_inputs)
217+
205218
def test_vulkan_backend_internal_data(self):
206219
class InternalDataModule(torch.nn.Module):
207220
def __init__(self):

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,51 @@ TEST(VulkanComputeGraphTest, test_values_string) {
627627
EXPECT_TRUE(stored == "hello, world");
628628
}
629629

630+
TEST(VulkanComputeGraphTest, test_zero_dim_tensor) {
631+
GraphConfig config;
632+
ComputeGraph graph(config);
633+
634+
std::vector<int64_t> size_big = {7, 3, 5};
635+
std::vector<int64_t> size_small = {};
636+
637+
// Build graph
638+
639+
IOValueRef a = graph.add_input_tensor(size_big, api::kFloat);
640+
IOValueRef b = graph.add_input_tensor(size_small, api::kFloat);
641+
642+
IOValueRef out = {};
643+
644+
out.value = graph.add_tensor(size_big, api::kFloat);
645+
646+
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
647+
addFn(graph, {a.value, b.value, kDummyValueRef, out.value});
648+
649+
out.staging = graph.set_output_tensor(out.value);
650+
651+
graph.prepare();
652+
graph.encode_execute();
653+
654+
// Run graph
655+
656+
for (float i = 5.0f; i < 30.0f; i += 10.0f) {
657+
float val_a = i + 2.0f;
658+
float val_b = i + 1.5f;
659+
float val_c = val_a + val_b;
660+
661+
fill_vtensor(graph, a, val_a);
662+
fill_vtensor(graph, b, val_b);
663+
664+
graph.execute();
665+
666+
EXTRACT_TENSOR(out);
667+
668+
// Sanity check that the values are correct
669+
for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
670+
CHECK_VALUE(data_out, i, val_c);
671+
}
672+
}
673+
}
674+
630675
TEST(VulkanComputeGraphTest, test_simple_graph) {
631676
GraphConfig config;
632677
ComputeGraph graph(config);

0 commit comments

Comments
 (0)