
Commit cd02c85

anijain2305 authored and pytorchmergebot committed
[inductor][subgraph][python-wrapper] Lift subgraph code into functions (pytorch#137200)
Earlier, the subgraphs were inlined into the output code. This PR lifts each subgraph into its own function, and the output code just calls that function.

This is the output code for the test `test_cond_reintepret_view_inputs_outputs`:

Before this PR - https://www.internalfb.com/intern/paste/P1632948905/
With this PR - https://www.internalfb.com/intern/paste/P1632946348/

A relevant snippet from the above paste is:

~~~
def false_graph_0(args):
    false_graph_0_arg0_1, false_graph_0_arg1_1, s0 = args
    args.clear()
    s0 = s0
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        false_graph_0_buf0 = empty_strided_cuda(((-1) + s0, 20), (20, 1), torch.float32)
        false_graph_0_buf1 = empty_strided_cuda(((-1) + s0, 20), (20, 1), torch.float32)
        # Unsorted Source Nodes: [cond, z1, z2], Original ATen: [aten.sub, aten.add]
        triton_poi_fused_add_sub_1_xnumel = (-20) + (20*s0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_sub_1.run(false_graph_0_arg0_1, false_graph_0_arg1_1, false_graph_0_buf0, false_graph_0_buf1, triton_poi_fused_add_sub_1_xnumel, grid=grid(triton_poi_fused_add_sub_1_xnumel), stream=stream0)
        del false_graph_0_arg0_1
        del false_graph_0_arg1_1
    return (reinterpret_tensor(false_graph_0_buf0, ((-3) + s0, 20), (20, 1), 40), reinterpret_tensor(false_graph_0_buf1, ((-1) + s0, 16), (20, 1), 4), )


async_compile.wait(globals())
del async_compile


def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1 = args
    args.clear()
    s0 = arg0_1
    assert_size_stride(arg1_1, (s0, 20), (20, 1))
    assert_size_stride(arg2_1, (s0, 20), (20, 1))
    assert_size_stride(arg3_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = [None] * 2
        buf0 = [None] * 2
        if arg3_1.item():
            # subgraph: true_graph_0
            true_graph_0_arg0_1 = reinterpret_tensor(arg1_1, ((-1) + s0, 20), (20, 1), 0)
            true_graph_0_arg1_1 = reinterpret_tensor(arg2_1, ((-1) + s0, 20), (20, 1), 0)
            (true_graph_0_buf0, true_graph_0_buf1) = true_graph_0([true_graph_0_arg0_1, true_graph_0_arg1_1, s0])
            buf0[0] = true_graph_0_buf0
            buf0[1] = true_graph_0_buf1
        else:
            # subgraph: false_graph_0
            false_graph_0_arg0_1 = reinterpret_tensor(arg1_1, ((-1) + s0, 20), (20, 1), 0)
            false_graph_0_arg1_1 = reinterpret_tensor(arg2_1, ((-1) + s0, 20), (20, 1), 0)
            (false_graph_0_buf0, false_graph_0_buf1) = false_graph_0([false_graph_0_arg0_1, false_graph_0_arg1_1, s0])
            buf0[0] = false_graph_0_buf0
            buf0[1] = false_graph_0_buf1
        del arg1_1
        del arg2_1
        del arg3_1
        buf1 = buf0[0]
        buf2 = buf0[1]
        del buf0
    return (buf1, buf2, )
~~~

The key change is to recursively call `codegen` for the subgraph and rely on `SubgraphPythonWrapper` to generate just the subgraph `fn`. The resulting `subgraph_code` is then inserted into the parent wrapper.

Note that this PR only works for the python wrapper. For the cpp wrapper, we need a lot of refactoring to ensure that we don't duplicate the global variables in the output_code. So, for now, I fall back to the old way of inlining for the cpp wrapper. I am hoping someone with more familiarity with the cpp wrapper can support subgraph lifting (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov).

This work will unblock hierarchical compilation (or the cold-start compile-time work).

Pull Request resolved: pytorch#137200
Approved by: https://github.com/desertfire, https://github.com/eellison
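For orientation, below is a minimal, self-contained toy sketch of the structural change described above. None of these names are inductor APIs; it only illustrates the idea of generating a subgraph's code once as a standalone function and emitting a call to it from the parent wrapper, instead of splicing the subgraph's statements inline.

~~~
# Toy sketch only (made-up helper names, not the inductor implementation):
# lift generated subgraph code into a function and call it from the parent.

def codegen_graph_body(name):
    # stand-in for recursively running wrapper codegen on the subgraph
    return [
        f"{name}_buf0 = compute({name}_arg0)",
        f"return ({name}_buf0,)",
    ]

def lift_subgraph(prologue, name):
    # emit the subgraph once, as a standalone function, ahead of `call`
    prologue.append(f"def {name}(args):")
    prologue.append(f"    {name}_arg0, = args")
    prologue.extend("    " + line for line in codegen_graph_body(name))
    prologue.append("")

def emit_call_site(parent_body, name, outer_input):
    # the parent wrapper now only emits a call to the lifted function
    parent_body.append(f"({name}_buf0,) = {name}([{outer_input}])")

prologue, parent_body = [], []
lift_subgraph(prologue, "false_graph_0")
emit_call_site(parent_body, "false_graph_0", "arg1_1")
print("\n".join(prologue + ["def call(args):"] + ["    " + l for l in parent_body]))
~~~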
1 parent 68272ab commit cd02c85

File tree

6 files changed: +405, -26 lines


test/inductor/test_control_flow.py

Lines changed: 92 additions & 0 deletions
~~~
@@ -335,6 +335,98 @@ def false_fn(x, y):
                 dynamic=True,
             )
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    def test_cond_unbacked_symint_outer_to_inner(self, device):
+        class Model(torch.nn.Module):
+            def forward(self, p, a):
+                def true_fn(x):
+                    return torch.cos(x)
+
+                def false_fn(x):
+                    return torch.sin(x)
+
+                nz = torch.nonzero(a)
+                b = torch.ones([nz.size(0), 8], device=nz.device)
+
+                return torch.cond(p, true_fn, false_fn, [b])
+
+        with torch._dynamo.config.patch(
+            {
+                "capture_dynamic_output_shape_ops": True,
+            }
+        ):
+            self._run_test(
+                model=Model(),
+                inputs=(torch.randn(2, 3, 3),),
+                device=device,
+                dynamic=True,
+            )
+
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    def test_cond_unbacked_symint_inner(self, device):
+        class Model(torch.nn.Module):
+            def forward(self, p, a):
+                def true_fn(x):
+                    nz = torch.nonzero(x)
+                    b = torch.ones([nz.size(0), 8], device=nz.device)
+                    return torch.cos(b)
+
+                def false_fn(x):
+                    nz = torch.nonzero(x)
+                    b = torch.ones([nz.size(0), 8], device=nz.device)
+                    return torch.sin(b)
+
+                b = torch.sin(a)
+
+                return torch.cond(p, true_fn, false_fn, [b])
+
+        with torch._dynamo.config.patch(
+            {
+                "capture_dynamic_output_shape_ops": True,
+            }
+        ):
+            self._run_test(
+                model=Model(),
+                inputs=(torch.randn(2, 3, 3),),
+                device=device,
+                dynamic=True,
+            )
+
+    @unittest.skip("unbacked symints from inner to outer graph not supported yet")
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    def test_cond_unbacked_symint_inner_to_outer(self, device):
+        class Model(torch.nn.Module):
+            def forward(self, p, a):
+                def true_fn(x):
+                    nz = torch.nonzero(x)
+                    b = torch.ones([nz.size(0), 8], device=nz.device)
+                    return torch.cos(b)
+
+                def false_fn(x):
+                    nz = torch.nonzero(x)
+                    b = torch.ones([nz.size(0), 8], device=nz.device)
+                    return torch.sin(b)
+
+                b = torch.sin(a)
+
+                y = torch.cond(p, true_fn, false_fn, [b])
+                return torch.sin(y)
+
+        with torch._dynamo.config.patch(
+            {
+                "capture_dynamic_output_shape_ops": True,
+            }
+        ):
+            self._run_test(
+                model=Model(),
+                inputs=(torch.randn(2, 3, 3),),
+                device=device,
+                dynamic=True,
+            )
+
     @requires_gpu
     def test_cond_use_buffers_from_outer_scope(self):
         # subgraphs input shapes include symbolic expressions
~~~
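The three new tests above all exercise `torch.cond` with sizes that come from data-dependent ops (`torch.nonzero`), i.e. unbacked symints either created in the outer graph and passed into the branches, or created inside the branches. A minimal sketch of that pattern outside the `_run_test` harness (assuming a recent PyTorch build where `torch.cond` and the `capture_dynamic_output_shape_ops` dynamo config are available) looks like:

~~~
import torch

def f(p, a):
    def true_fn(x):
        return torch.cos(x)

    def false_fn(x):
        return torch.sin(x)

    nz = torch.nonzero(a)                             # data-dependent size -> unbacked symint
    b = torch.ones([nz.size(0), 8], device=a.device)  # shape depends on the nonzero count
    return torch.cond(p, true_fn, false_fn, [b])

with torch._dynamo.config.patch({"capture_dynamic_output_shape_ops": True}):
    compiled = torch.compile(f, dynamic=True)
    out = compiled(torch.tensor(True), torch.randn(2, 3, 3))
~~~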

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 27 additions & 0 deletions
~~~
@@ -70,6 +70,14 @@ def __init__(self):
         self.initialized_kernels: Dict[str, Kernel] = {}
         self.expr_printer = cexpr
 
+    @staticmethod
+    def create(
+        is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperCpu()
+
     def generate_kernel_call(
         self,
         kernel_name: str,
@@ -1912,6 +1920,25 @@ def codegen_conditional(self, conditional):
             self.writeline(ExitSubgraphLine(self))
         self.writeline("}")
 
+    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
+        # TODO (desertfire) - This function is the old way of supporting
+        # subgraph codegen by inlining subgraphs in the output code. For python
+        # wrapper, we have moved to lifting subgraphs as functions, supported by
+        # PythonWrapperCode `codegen_subgraph` function. We should perhaps
+        # support lifting of subgraphs as functions for cpp wrapper as well.
+        try:
+            self.push_codegened_graph(subgraph.graph)
+            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
+            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
+            parent_graph = V.graph
+            with V.set_graph_handler(subgraph.graph):
+                subgraph.graph.codegen_subgraph(
+                    parent_graph=parent_graph,
+                )
+            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
+        finally:
+            self.pop_codegened_graph()
+
     def codegen_while_loop(self, while_loop):
         name = while_loop.get_name()
         outer_carried_inputs = [
~~~
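The `create` override added here (and the matching ones in `CppWrapperCpuArrayRef` and `CppWrapperGpu` below) ignores the subgraph arguments, which is what keeps cpp wrapper codegen on the old inlining path for now. A hypothetical, self-contained sketch of the factory shape involved is shown below; the class names other than those in the diff are stand-ins, not the real inductor classes.

~~~
# Hypothetical sketch, not inductor code: a wrapper factory that can either
# return a subgraph-specific wrapper (which generates a lifted function) or a
# plain top-level wrapper. The cpp overrides in this PR correspond to the
# second branch only, i.e. they always fall back to inlining subgraphs.
class ToyPythonWrapper:
    def __init__(self, subgraph_name=None, parent_wrapper=None):
        self.subgraph_name = subgraph_name
        self.parent_wrapper = parent_wrapper

    @staticmethod
    def create(is_subgraph, subgraph_name, parent_wrapper):
        if is_subgraph:
            # generate only `def <subgraph_name>(args): ...` for the parent to call
            return ToyPythonWrapper(subgraph_name, parent_wrapper)
        return ToyPythonWrapper()


class ToyCppWrapper(ToyPythonWrapper):
    @staticmethod
    def create(is_subgraph, subgraph_name, parent_wrapper):
        # mirror of the TODOs in the diff: ignore the subgraph arguments for now
        return ToyCppWrapper()
~~~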

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 9 additions & 0 deletions
~~~
@@ -15,6 +15,7 @@
     ExitSubgraphLine,
     MemoryPlanningLine,
     MemoryPlanningState,
+    PythonWrapperCodegen,
 )
 
 
@@ -72,6 +73,14 @@ def __init__(self):
         self.allow_stack_allocation: Optional[bool] = None
         self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
 
+    @staticmethod
+    def create(
+        is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperCpuArrayRef()
+
     def memory_plan(self):
         from .memory_planning import MemoryPlanner
 
~~~

torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 9 additions & 1 deletion
~~~
@@ -18,7 +18,7 @@
 from .common import get_device_op_overrides
 from .cpp_utils import cexpr, DTYPE_TO_CPP
 from .cpp_wrapper_cpu import CppWrapperCpu
-from .wrapper import SymbolicCallArg
+from .wrapper import PythonWrapperCodegen, SymbolicCallArg
 
 
 if TYPE_CHECKING:
@@ -171,6 +171,14 @@ def __init__(self) -> None:
         super().__init__()
         self.grid_id = count()
 
+    @staticmethod
+    def create(
+        is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen
+    ):
+        # TODO - support subgraph codegen by lifting functions. Check the
+        # comment at CppWrapperCpu `codegen_subgraph` function.
+        return CppWrapperGpu()
+
     def write_header(self):
         if V.graph.is_const_graph:
             # We do not write header for constant graph, it will be written by main module.
~~~
