Skip to content

Commit ce8359d

Browse files
authored
Use GraphBuilder in memory passes unit tests. # 1
Differential Revision: D75467583 Pull Request resolved: #11265
1 parent 489eea2 commit ce8359d

File tree

3 files changed

+324
-167
lines changed

3 files changed

+324
-167
lines changed

backends/cadence/aot/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ python_unittest(
433433
srcs = [
434434
"tests/test_memory_passes.py",
435435
],
436+
supports_static_listing = False,
436437
typing = True,
437438
deps = [
438439
":compiler",
@@ -441,7 +442,9 @@ python_unittest(
441442
":pass_utils",
442443
"//caffe2:torch",
443444
"//executorch/exir:memory",
445+
"fbsource//third-party/pypi/parameterized:parameterized",
444446
"//executorch/exir/dialects:lib",
447+
"//executorch/backends/cadence/aot:graph_builder",
445448
"//executorch/exir/tests:models",
446449
],
447450
)

backends/cadence/aot/memory_constraints.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,28 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
350350
def is_cat_along_outermost_dim(
351351
self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node
352352
) -> bool:
353+
assert len(cat_node.args) > 0
354+
cat_tensors = cat_node.args[0]
355+
if not isinstance(cat_tensors, Sequence) or not all(
356+
isinstance(t, torch.fx.Node) for t in cat_tensors
357+
):
358+
raise ValueError("cat_tensors must be a sequence of torch.fx.Node objects.")
359+
360+
if len(cat_node.args) > 1:
361+
cat_dim = cat_node.args[1]
362+
else:
363+
cat_dim = cat_node.kwargs.get("dim", None)
364+
if not isinstance(cat_dim, int):
365+
raise ValueError("cat_dim must be an integer.")
366+
353367
# If the cat op has default dim, then the concat dim is 0
354-
if len(cat_node.args) == 1 or cat_node.args[1] == 0:
368+
if len(cat_tensors) == 1 or cat_dim == 0:
355369
return True
356-
# Get the concatenation dimension and concatenated tensors
357-
(cat_tensors, cat_dim) = cast(
358-
tuple[Sequence[torch.fx.Node], int], cat_node.args
359-
)
370+
371+
# Make sure all dims before cat_dim are 1.
360372
for tensor in cat_tensors:
373+
if not isinstance(tensor, torch.fx.Node):
374+
continue
361375
shape = get_shape(graph_module, tensor)
362376
if shape is None or not all(dim == 1 for dim in shape[0:cat_dim]):
363377
return False

0 commit comments

Comments
 (0)