Skip to content

Commit 7c182ea

Browse files
committed
Update on "[8/n][ET-VK] Support staging any 8-bit texture"
Changes following from #4485 to support `texture2d` and support `uint8`, respectively. Differential Revision: [D63918659](https://our.internmc.facebook.com/intern/diff/D63918659/) [ghstack-poisoned]
2 parents 88ab798 + 6213b90 commit 7c182ea

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void add_flip_node(
6565
graph.create_local_wg_size(out),
6666
// Inputs and Outputs
6767
{
68-
{out, vkapi::MemoryAccessType::WRITE},
69-
{in, vkapi::MemoryAccessType::READ},
68+
{out, vkapi::kWrite},
69+
{in, vkapi::kRead},
7070
},
7171
// Parameter buffers
7272
{

backends/vulkan/test/op_tests/cases.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,30 @@ def get_clone_inputs():
704704

705705
@register_test_suite("aten.repeat.default")
706706
def get_repeat_inputs():
707-
test_suite = VkTestSuite(
707+
test_suite_2d = VkTestSuite(
708+
[
709+
# Repeat channels only (most challenging case)
710+
((3, XS, S), [2, 1, 1]),
711+
((7, XS, S), [4, 1, 1]),
712+
# More other cases
713+
((2, 3), [1, 4]),
714+
((2, 3), [4, 1]),
715+
((2, 3), [4, 4]),
716+
((S1, S2, S2), [1, 3, 1]),
717+
((S1, S2, S2), [1, 3, 3]),
718+
((S1, S2, S2), [3, 3, 1]),
719+
((S1, S2, S2), [3, 3, 3]),
720+
# Expanding cases
721+
((2, 3), [3, 1, 4]),
722+
]
723+
)
724+
test_suite_2d.layouts = ["utils::kChannelsPacked"]
725+
test_suite_2d.storage_types = ["utils::kTexture2D"]
726+
test_suite_2d.data_gen = "make_seq_tensor"
727+
test_suite_2d.dtypes = ["at::kFloat"]
728+
test_suite_2d.test_name_suffix = "2d"
729+
730+
test_suite_3d = VkTestSuite(
708731
[
709732
# Repeat channels only (most challenging case)
710733
((3, XS, S), [2, 1, 1]),
@@ -739,13 +762,13 @@ def get_repeat_inputs():
739762
((2, 3), [3, 3, 2, 4]),
740763
]
741764
)
742-
test_suite.layouts = [
743-
"utils::kChannelsPacked",
744-
]
745-
test_suite.storage_types = ["utils::kTexture2D", "utils::kTexture3D"]
746-
test_suite.data_gen = "make_seq_tensor"
747-
test_suite.dtypes = ["at::kFloat"]
748-
return test_suite
765+
test_suite_3d.layouts = ["utils::kChannelsPacked"]
766+
test_suite_3d.storage_types = ["utils::kTexture3D"]
767+
test_suite_3d.data_gen = "make_seq_tensor"
768+
test_suite_3d.dtypes = ["at::kFloat"]
769+
test_suite_2d.test_name_suffix = "3d"
770+
771+
return [test_suite_2d, test_suite_3d]
749772

750773

751774
@register_test_suite("aten.repeat_interleave.self_int")
@@ -1164,7 +1187,7 @@ def get_squeeze_copy_dim_inputs():
11641187

11651188
@register_test_suite("aten.flip.default")
11661189
def get_flip_inputs():
1167-
Test = namedtuple("VkIndexSelectTest", ["self", "dim"])
1190+
Test = namedtuple("Flip", ["self", "dim"])
11681191
Test.__new__.__defaults__ = (None, 0)
11691192

11701193
test_cases = [

0 commit comments

Comments
 (0)