Skip to content

Commit 730645d

Browse files
alexbeloiWei Wei
authored andcommitted
[fx][acc_ops] add acc_ops.gather and acc_ops.index_select and shape inference (#30)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/30 as title Reviewed By: 842974287 Differential Revision: D34874487 fbshipit-source-id: 86039d0f1269d879983977c65fe8fbe0a8bc1421
1 parent 6165038 commit 730645d

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

test/tracer/test_acc_tracer.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,6 +2102,78 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
21022102
gm_retrace = acc_tracer.trace(gm, [a])
21032103
self.assertTrue(torch.equal(m(a), gm_retrace(a)))
21042104

2105+
def test_index_select(self):
2106+
class TestModule(nn.Module):
2107+
def __init__(self, dim, index):
2108+
super().__init__()
2109+
self._dim = dim
2110+
self._index = index
2111+
2112+
def forward(self, a: torch.Tensor) -> torch.Tensor:
2113+
return torch.index_select(a, self._dim, self._index)
2114+
2115+
dim = 0
2116+
index = torch.tensor([1, 0])
2117+
m = TestModule(dim, index)
2118+
_input = [torch.randn(2, 3), torch.randn(2, 3)]
2119+
traced = acc_tracer.trace(m, _input)
2120+
2121+
ph = index = index_select = None
2122+
2123+
for node in traced.graph.nodes:
2124+
if node.op == "placeholder":
2125+
self.assertEqual(str(node.target), "a")
2126+
ph = node
2127+
elif node.op == "call_function" and node.target == acc_ops.index_select:
2128+
self.assertTrue(node.kwargs["input"] == ph)
2129+
self.assertTrue(node.kwargs["index"] == index)
2130+
self.assertTrue(node.kwargs["dim"] == dim)
2131+
index_select = node
2132+
elif node.op == "output":
2133+
self.assertEqual(index_select, node.args[0])
2134+
elif node.op == "get_attr":
2135+
# There only be one™ const node
2136+
self.assertTrue(index is None)
2137+
index = node
2138+
else:
2139+
self.fail(f"Unexpected node: {node.format_node()}")
2140+
2141+
def test_gather(self):
2142+
class TestModule(nn.Module):
2143+
def __init__(self, dim, index):
2144+
super().__init__()
2145+
self._dim = dim
2146+
self._index = index
2147+
2148+
def forward(self, a: torch.Tensor) -> torch.Tensor:
2149+
return torch.gather(a, self._dim, self._index)
2150+
2151+
dim = 0
2152+
index = torch.tensor([[1, 0], [0, 1]])
2153+
m = TestModule(dim, index)
2154+
_input = [torch.randn(2, 3), torch.randn(2, 3)]
2155+
traced = acc_tracer.trace(m, _input)
2156+
2157+
ph = index = gather = None
2158+
2159+
for node in traced.graph.nodes:
2160+
if node.op == "placeholder":
2161+
self.assertEqual(str(node.target), "a")
2162+
ph = node
2163+
elif node.op == "call_function" and node.target == acc_ops.gather:
2164+
self.assertTrue(node.kwargs["input"] == ph)
2165+
self.assertTrue(node.kwargs["index"] == index)
2166+
self.assertTrue(node.kwargs["dim"] == dim)
2167+
gather = node
2168+
elif node.op == "output":
2169+
self.assertEqual(gather, node.args[0])
2170+
elif node.op == "get_attr":
2171+
# There only be one™ const node
2172+
self.assertTrue(index is None)
2173+
index = node
2174+
else:
2175+
self.fail(f"Unexpected node: {node.format_node()}")
2176+
21052177
def test_all_acc_ops_registered(self):
21062178
self.assertEqual(
21072179
acc_normalizer._acc_ops,
@@ -2203,5 +2275,7 @@ def test_all_acc_ops_registered(self):
22032275
acc_ops.eq,
22042276
acc_ops.gt,
22052277
acc_ops.le,
2278+
acc_ops.gather,
2279+
acc_ops.index_select,
22062280
},
22072281
)

tracer/acc_tracer/acc_ops.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,15 @@ def adaptive_avg_pool2d(*, input, output_size):
130130

131131
@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool1d))
132132
@register_acc_op
133-
def avg_pool1d(*, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
133+
def avg_pool1d(
134+
*,
135+
input,
136+
kernel_size,
137+
stride,
138+
padding,
139+
ceil_mode,
140+
count_include_pad,
141+
):
134142
return nn.functional.avg_pool1d(
135143
input=input,
136144
kernel_size=kernel_size,
@@ -2163,3 +2171,24 @@ def cumsum(*, input, dim, dtype=None):
21632171
@register_acc_op
21642172
def chunk(*, input, chunks, dim=0):
21652173
return torch.chunk(input=input, chunks=chunks, dim=dim)
2174+
2175+
2176+
@register_acc_op_mapping(op_and_target=("call_function", torch.gather),
2177+
arg_replacement_tuples=[
2178+
("input", "input"),
2179+
("dim", "dim"),
2180+
("index", "index"),
2181+
("sparse_grad", "sparse_grad", this_arg_is_optional),
2182+
],
2183+
)
2184+
@register_acc_op
2185+
def gather(*, input, dim, index, sparse_grad=False):
2186+
return torch.gather(input=input, dim=dim, index=index, sparse_grad=sparse_grad)
2187+
2188+
2189+
@register_acc_op_mapping(
2190+
op_and_target=("call_function", torch.index_select),
2191+
)
2192+
@register_acc_op
2193+
def index_select(*, input, dim, index):
2194+
return torch.index_select(input, dim, index)

0 commit comments

Comments
 (0)