
Commit 0f2995f

Authored by pytorchbot and SS-JIA
[ET-VK] Enforce GPU buffer limit when partitioning (#6856)
Pull Request resolved: #6829

## Context

In Vulkan, there is a limit on the number of elements a GPU buffer can have. If a GPU buffer exceeds this limit, the API will either produce an error or exhibit undefined behaviour.

## Changes

Along with `texture_limits`, introduce a configurable `buffer_limit` entry in the partitioner configuration.

ghstack-source-id: 253568943
Differential Revision: [D65899828](https://our.internmc.facebook.com/intern/diff/D65899828/)
Co-authored-by: Stephen Jia <[email protected]>
1 parent 07c4d0e commit 0f2995f
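
For reference, a lowering flow that opts into the new option might look like the sketch below. This is a hedged illustration: the diff shows the partitioner reading "buffer_limit" out of an options dict via self.options.get(...), but the exact VulkanPartitioner constructor signature and the to_edge/to_backend entry points used here are assumptions, not something this commit documents.

# Hypothetical usage sketch: pass a custom buffer_limit (an element count)
# through the partitioner options dict, alongside the existing texture_limits entry.
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge

class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y

inputs = (torch.randn(1, 64, 64), torch.randn(1, 64, 64))
exported = torch.export.export(AddModule(), inputs)

# Cap GPU buffers at 64M elements instead of the 128M-element default.
partitioner = VulkanPartitioner({"buffer_limit": 64 * 1024 * 1024})
edge = to_edge(exported).to_backend(partitioner)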

File tree

backends/vulkan/partitioner/vulkan_partitioner.py
backends/vulkan/utils.py

2 files changed: +35 −2 lines


backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 18 additions & 2 deletions
@@ -51,11 +51,15 @@
 
 class VulkanSupportedOperators(OperatorSupportBase):
     def __init__(
-        self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False
+        self,
+        texture_limits: utils.ImageExtents,
+        buffer_limit: int,
+        require_dynamic_shape: bool = False,
     ) -> None:
         super().__init__()
-        self.require_dynamic_shapes = require_dynamic_shape
         self.texture_limits: utils.ImageExtents = texture_limits
+        self.buffer_limit = buffer_limit
+        self.require_dynamic_shapes = require_dynamic_shape
 
     def op_node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -83,6 +87,7 @@ def op_node_is_compatible(
             node, self.texture_limits
         )
 
+        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
         for i, arg in enumerate(node.args):
             if (
                 isinstance(arg, torch.fx.Node)
@@ -95,10 +100,19 @@ def op_node_is_compatible(
                 valid_texture_layouts = valid_texture_layouts.intersection(
                     arg_texture_layouts
                 )
+                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
+                    arg, self.buffer_limit
+                )
 
         # If there are no valid texture memory layouts, then buffer storage must be
         # supported by the operator implementation.
         if len(valid_texture_layouts) == 0:
+            if not can_use_buffers:
+                return (
+                    False,
+                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
+                )
+
             compatible = VkStorageType.BUFFER in features.supported_storage_types()
             reason = "op is compatible"
             if not compatible:
@@ -309,10 +323,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         texture_limits: utils.ImageExtents = self.options.get(
             "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
         )
+        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
             VulkanSupportedOperators(
                 texture_limits,
+                buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
             ),
             allows_single_node_partition=True,
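
The net effect of the partitioner change: buffer storage is only considered when no texture memory layout fits the op, and that fallback is now rejected outright if any involved tensor exceeds the buffer element limit. Below is a condensed, standalone sketch of that decision flow; the helper and its inputs are simplified stand-ins for illustration, not the actual module.

# Standalone illustration of the gating logic above; tensor fit checks are
# reduced to plain numel() comparisons and the texture check is a boolean input.
import torch

def op_is_compatible(tensors, has_valid_texture_layout, buffer_limit, op_supports_buffers):
    # Every tensor touched by the op must fit in a GPU buffer for the
    # buffer-storage fallback to be viable.
    can_use_buffers = all(t.numel() < buffer_limit for t in tensors)

    if not has_valid_texture_layout:
        if not can_use_buffers:
            return False, f"op requires buffers that exceed the buffer limit ({buffer_limit})"
        return op_supports_buffers, "buffer storage fallback"
    return True, "texture storage available"

# A 1024x1024 tensor (1,048,576 elements) against a deliberately tiny limit:
big = torch.zeros(1024, 1024)
print(op_is_compatible([big], has_valid_texture_layout=False,
                       buffer_limit=1_000_000, op_supports_buffers=True))
# -> (False, 'op requires buffers that exceed the buffer limit (1000000)')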

backends/vulkan/utils.py

Lines changed: 17 additions & 0 deletions
@@ -87,6 +87,7 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
 ImageExtents = Tuple[int, int, int]
 
 DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
+DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)
 
 
 class PackedDim(IntEnum):
@@ -113,6 +114,22 @@ class PackedDim(IntEnum):
 }
 
 
+def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
+    """
+    Checks whether the tensors produced by the given node can fit within the device's
+    GPU buffer limit, which represents the maximum number of elements that can be stored
+    in a GPU buffer.
+    """
+    assert is_tensor_node(node)
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return node.meta["val"].numel() < buffer_limit
+    elif isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        return all(x.numel() < buffer_limit for x in node.meta["val"])
+    else:
+        raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
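
For a sense of scale, the new DEFAULT_BUFFER_LIMIT is an element count (128 * 1024 * 1024 = 134,217,728 elements, i.e. 512 MiB of float32 data), not a byte count. The snippet below only illustrates that semantics; it is not code from the commit.

# Quick illustration of the element-count semantics of DEFAULT_BUFFER_LIMIT.
import math

DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)  # 134,217,728 elements

for shape in [(1, 3, 4096, 4096), (1, 3, 8192, 8192)]:
    numel = math.prod(shape)
    verdict = "fits" if numel < DEFAULT_BUFFER_LIMIT else "exceeds the limit"
    print(shape, f"{numel:,} elements ->", verdict)
# (1, 3, 4096, 4096) 50,331,648 elements -> fits
# (1, 3, 8192, 8192) 201,326,592 elements -> exceeds the limit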
