Skip to content

Commit 67502f2

Browse files
committed
[ET-VK] Enable storage type and memory layout settings to be serialized with Vulkan graph
Pull Request resolved: #2540 ## Context Allow `api::StorageType` and `api::GPUMemoryLayout` settings to be serialized with the flatbuffer. There are two entry points for this: 1. `VkTensor` table now has two fields that can be set to select particular settings for that tensor 2. A storage type and memory layout override can be set via the `CompileSpec` API ghstack-source-id: 219475440 Differential Revision: [D55154628](https://our.internmc.facebook.com/intern/diff/D55154628/)
1 parent 80e3989 commit 67502f2

File tree

7 files changed

+207
-36
lines changed

7 files changed

+207
-36
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import operator
8-
from typing import final, List, Optional
8+
from typing import Any, Dict, final, List, Optional
9+
10+
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
911

1012
import torch
1113
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
@@ -45,11 +47,29 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4547
return supported
4648

4749

50+
def parse_compile_options(
51+
compile_options: Optional[Dict[str, Any]] = None
52+
) -> List[CompileSpec]:
53+
compile_specs = []
54+
if compile_options is None:
55+
return compile_specs
56+
57+
for key, value in compile_options.items():
58+
if isinstance(
59+
value, (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout)
60+
):
61+
value_bytes = int(value).to_bytes(4, byteorder="little")
62+
compile_specs.append(CompileSpec(key, value_bytes))
63+
else:
64+
raise RuntimeError(f"Invalid compile option {key} with type {type(value)}")
65+
66+
return compile_specs
67+
68+
4869
@final
4970
class VulkanPartitioner(Partitioner):
50-
def __init__(self, compile_spec: Optional[List[CompileSpec]] = None) -> None:
51-
if compile_spec is None:
52-
compile_spec = []
71+
def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None:
72+
compile_spec = parse_compile_options(compile_options)
5373
self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
5474

5575
def partition(self, exported_program: ExportedProgram) -> PartitionResult:

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <cstdio>
2424
#include <cstdlib> /* strtol */
25+
#include <cstring>
2526
#include <memory>
2627
#include <type_traits>
2728
#include <vector>
@@ -72,6 +73,62 @@ api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
7273
}
7374
}
7475

76+
api::StorageType get_storage_type(
77+
const vkgraph::VkStorageType& vk_storage_type) {
78+
switch (vk_storage_type) {
79+
case vkgraph::VkStorageType::BUFFER:
80+
return api::StorageType::BUFFER;
81+
case vkgraph::VkStorageType::TEXTURE_3D:
82+
return api::StorageType::TEXTURE_3D;
83+
case vkgraph::VkStorageType::TEXTURE_2D:
84+
return api::StorageType::TEXTURE_2D;
85+
default:
86+
break;
87+
}
88+
return api::StorageType::UNKNOWN;
89+
}
90+
91+
api::GPUMemoryLayout get_memory_layout(
92+
const vkgraph::VkMemoryLayout& vk_memory_layout) {
93+
switch (vk_memory_layout) {
94+
case vkgraph::VkMemoryLayout::TENSOR_WIDTH_PACKED:
95+
return api::GPUMemoryLayout::TENSOR_WIDTH_PACKED;
96+
case vkgraph::VkMemoryLayout::TENSOR_HEIGHT_PACKED:
97+
return api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED;
98+
case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
99+
return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
100+
default:
101+
break;
102+
}
103+
VK_THROW("Invalid memory layout encountered!");
104+
}
105+
106+
GraphConfig get_graph_config(ArrayRef<CompileSpec>& compile_specs) {
107+
GraphConfig config = GraphConfig();
108+
109+
for (const CompileSpec& spec : compile_specs) {
110+
const uint8_t* value_data = (const uint8_t*)spec.value.buffer;
111+
const size_t value_size = spec.value.nbytes;
112+
if (strcmp(spec.key, "storage_type_override") == 0) {
113+
ET_CHECK_MSG(value_size == sizeof(int32_t), "Unexpected value size!");
114+
int value_as_int = static_cast<int>(GetUInt32LE(value_data));
115+
api::StorageType storage_type =
116+
static_cast<api::StorageType>(value_as_int);
117+
118+
config.setStorageTypeOverride(storage_type);
119+
}
120+
if (strcmp(spec.key, "memory_layout_override") == 0) {
121+
ET_CHECK_MSG(value_size == sizeof(uint32_t), "Unexpected value size!");
122+
uint32_t value_as_int = GetUInt32LE(value_data);
123+
api::GPUMemoryLayout memory_layout =
124+
static_cast<api::GPUMemoryLayout>(value_as_int);
125+
126+
config.setMemoryLayoutOverride(memory_layout);
127+
}
128+
}
129+
return config;
130+
}
131+
75132
class GraphBuilder {
76133
ComputeGraph* compute_graph_;
77134
VkGraphPtr flatbuffer_;
@@ -109,10 +166,19 @@ class GraphBuilder {
109166

110167
void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
111168
const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());
169+
api::StorageType storage_type =
170+
tensor_fb->storage_type() == vkgraph::VkStorageType::DEFAULT_STORAGE
171+
? compute_graph_->suggested_storage_type()
172+
: get_storage_type(tensor_fb->storage_type());
112173

113174
UIntVector dims_fb = tensor_fb->dims();
114175
const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());
115176

177+
api::GPUMemoryLayout memory_layout =
178+
tensor_fb->memory_layout() == vkgraph::VkMemoryLayout::DEFAULT_LAYOUT
179+
? compute_graph_->suggested_memory_layout(dims_vector)
180+
: get_memory_layout(tensor_fb->memory_layout());
181+
116182
ValueRef ref;
117183
if (tensor_fb->constant_id() >= 0) {
118184
const uint8_t* tensor_data = getConstantDataPtr(
@@ -121,7 +187,11 @@ class GraphBuilder {
121187
ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
122188
} else {
123189
ref = compute_graph_->add_tensor(
124-
dims_vector, dtype, tensor_fb->mem_obj_id());
190+
dims_vector,
191+
dtype,
192+
storage_type,
193+
memory_layout,
194+
tensor_fb->mem_obj_id());
125195
}
126196

127197
ref_mapping_[fb_id] = ref;
@@ -371,11 +441,11 @@ class VulkanBackend final : public PyTorchBackendInterface {
371441
Result<DelegateHandle*> init(
372442
BackendInitContext& context,
373443
FreeableBuffer* processed,
374-
ArrayRef<CompileSpec>) const override {
444+
ArrayRef<CompileSpec> compile_specs) const override {
375445
ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
376446
context.get_runtime_allocator(), ComputeGraph);
377447

378-
new (compute_graph) ComputeGraph(GraphConfig());
448+
new (compute_graph) ComputeGraph(get_graph_config(compile_specs));
379449

380450
Error err = compileModel(processed->data(), compute_graph);
381451

backends/vulkan/runtime/VulkanDelegateHeader.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ constexpr ByteSlice kFlatbufferSize = {14, 4};
3636
constexpr ByteSlice kBytesOffset = {18, 4};
3737
constexpr ByteSlice kBytesSize = {22, 8};
3838

39+
} // namespace
40+
3941
/// Interprets the 8 bytes at `data` as a little-endian uint64_t.
4042
uint64_t GetUInt64LE(const uint8_t* data) {
4143
return (uint64_t)data[0] | ((uint64_t)data[1] << 8) |
@@ -55,8 +57,6 @@ uint32_t GetUInt16LE(const uint8_t* data) {
5557
return (uint32_t)data[0] | ((uint32_t)data[1] << 8);
5658
}
5759

58-
} // namespace
59-
6060
bool VulkanDelegateHeader::is_valid() const {
6161
if (header_size < kExpectedSize) {
6262
return false;

backends/vulkan/runtime/VulkanDelegateHeader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ namespace torch {
1414
namespace executor {
1515
namespace vulkan {
1616

17+
// Byte decoding utilities
18+
uint64_t GetUInt64LE(const uint8_t* data);
19+
uint32_t GetUInt32LE(const uint8_t* data);
20+
uint32_t GetUInt16LE(const uint8_t* data);
21+
1722
struct VulkanDelegateHeader {
1823
bool is_valid() const;
1924

backends/vulkan/serialization/schema.fbs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,26 @@ enum VkDataType : byte {
1919
FLOAT32 = 5,
2020
}
2121

22+
// Describes what kind of GPU resource should be used to represent a tensor. The
23+
// int values assigned to each entry must match the corresponding entry in
24+
// api::StorageType.
25+
enum VkStorageType : ubyte {
26+
BUFFER = 0,
27+
TEXTURE_3D = 1,
28+
TEXTURE_2D = 2,
29+
DEFAULT_STORAGE = 255,
30+
}
31+
32+
// Describes how memory should be laid out in GPU memory. See the GPUMemoryLayout
33+
// enum class in PyTorch Vulkan for more details. The int values assigned to each
34+
// entry must match the corresponding entry in api::GPUMemoryLayout.
35+
enum VkMemoryLayout : ubyte {
36+
TENSOR_WIDTH_PACKED = 0,
37+
TENSOR_HEIGHT_PACKED = 1,
38+
TENSOR_CHANNELS_PACKED = 2,
39+
DEFAULT_LAYOUT = 255,
40+
}
41+
2242
table VkTensor {
2343
// Type of the tensor elements.
2444
datatype:VkDataType;
@@ -28,6 +48,10 @@ table VkTensor {
2848
constant_id:int;
2949
// Index to the shared memory object. Negative indicates the tensor doesn't share memory.
3050
mem_obj_id:int;
51+
// Storage type that should be used to represent this tensor
52+
storage_type:VkStorageType = DEFAULT_STORAGE;
53+
// Memory layout that should be used to represent this tensor
54+
memory_layout:VkMemoryLayout = DEFAULT_LAYOUT;
3155
}
3256

3357
table Null {}
@@ -103,6 +127,17 @@ table VkGraph {
103127
// Raw Objects (e.g. weight tensors and custom shaders)
104128
constants:[VkBytes];
105129
shaders:[VkBytes];
130+
131+
// Graph configuration
132+
// As per flatbuffer BC/FC policy, new fields can be freely added to this
133+
// section. It is recommended to provide default values, since older blobs
134+
// without the field will be deserialized with the default value.
135+
136+
// Sets an override for the storage type and memory layout that will be used
137+
// to represent a VkTensor if the VkTensor is not serialized with a particular
138+
// storage type or memory layout setting
139+
storage_type_override:VkStorageType = DEFAULT_STORAGE;
140+
memory_layout_override:VkMemoryLayout = DEFAULT_LAYOUT;
106141
}
107142

108143
root_type VkGraph;

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,28 @@ class VkDataType(IntEnum):
3030
FLOAT32 = 5
3131

3232

33+
class VkStorageType(IntEnum):
34+
BUFFER = 0
35+
TEXTURE_3D = 1
36+
TEXTURE_2D = 2
37+
DEFAULT_STORAGE = 255
38+
39+
40+
class VkMemoryLayout(IntEnum):
41+
TENSOR_WIDTH_PACKED = 0
42+
TENSOR_HEIGHT_PACKED = 1
43+
TENSOR_CHANNELS_PACKED = 2
44+
DEFAULT_LAYOUT = 255
45+
46+
3347
@dataclass
3448
class VkTensor:
3549
datatype: VkDataType
3650
dims: List[int]
3751
constant_id: int
3852
mem_obj_id: int
53+
storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE
54+
memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
3955

4056

4157
@dataclass
@@ -120,3 +136,6 @@ class VkGraph:
120136

121137
constants: List[VkBytes]
122138
shaders: List[VkBytes]
139+
140+
storage_type_override: VkStorageType = VkStorageType.DEFAULT_STORAGE
141+
memory_layout_override: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import unittest
99
from typing import Tuple
1010

11+
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
12+
1113
import torch
1214

1315
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
@@ -56,46 +58,66 @@ def lower_module_and_test_output(
5658
rtol=1e-01,
5759
dynamic_shapes=None,
5860
test_inputs=None,
61+
memory_layouts=None,
5962
):
6063
"""
6164
Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
6265
the given sample inputs. It then runs the lowered module and compares its
6366
outputs with the outputs of the eager module.
6467
"""
65-
program: ExportedProgram = export(
66-
model, sample_inputs, dynamic_shapes=dynamic_shapes
67-
)
68-
edge_program: EdgeProgramManager = to_edge(program)
69-
edge_program = edge_program.to_backend(VulkanPartitioner())
7068

71-
executorch_program = edge_program.to_executorch()
69+
def run_test(memory_layout):
70+
compile_options = {
71+
"memory_layout_override": memory_layout,
72+
}
73+
program: ExportedProgram = export(
74+
model, sample_inputs, dynamic_shapes=dynamic_shapes
75+
)
76+
edge_program: EdgeProgramManager = to_edge(program)
7277

73-
self.assertEqual(
74-
executorch_program.executorch_program.execution_plan[0].delegates[0].id,
75-
VulkanBackend.__name__,
76-
)
78+
edge_program = edge_program.to_backend(VulkanPartitioner(compile_options))
7779

78-
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
79-
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
80-
inputs_flattened, _ = tree_flatten(sample_inputs)
80+
executorch_program = edge_program.to_executorch()
8181

82-
model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
83-
ref_output = model(*sample_inputs)
82+
self.assertEqual(
83+
executorch_program.executorch_program.execution_plan[0].delegates[0].id,
84+
VulkanBackend.__name__,
85+
)
8486

85-
self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
87+
executorch_module = _load_for_executorch_from_buffer(
88+
executorch_program.buffer
89+
)
90+
inputs_flattened, _ = tree_flatten(sample_inputs)
8691

87-
if test_inputs is not None:
88-
for test_input in test_inputs:
89-
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
90-
test_inputs_flattened, _ = tree_flatten(test_input)
91-
model_output = executorch_module.run_method(
92-
"forward", tuple(test_inputs_flattened)
93-
)
94-
ref_output = model(*test_input)
92+
model_output = executorch_module.run_method(
93+
"forward", tuple(inputs_flattened)
94+
)
95+
ref_output = model(*sample_inputs)
9596

96-
self.assert_outputs_equal(
97-
model_output, ref_output, atol=atol, rtol=rtol
98-
)
97+
self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
98+
99+
if test_inputs is not None:
100+
for test_input in test_inputs:
101+
test_inputs_flattened, _ = tree_flatten(test_input)
102+
model_output = executorch_module.run_method(
103+
"forward", tuple(test_inputs_flattened)
104+
)
105+
ref_output = model(*test_input)
106+
107+
self.assert_outputs_equal(
108+
model_output, ref_output, atol=atol, rtol=rtol
109+
)
110+
111+
memory_layouts_to_test = [
112+
vk_graph_schema.VkMemoryLayout.TENSOR_WIDTH_PACKED,
113+
vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED,
114+
]
115+
116+
if memory_layouts is not None:
117+
memory_layouts_to_test = memory_layouts
118+
119+
for memory_layout in memory_layouts_to_test:
120+
run_test(memory_layout)
99121

100122
def test_vulkan_backend_add(self):
101123
# This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts

0 commit comments

Comments
 (0)