Commit bece720

Update usage of PyTorch's custom op API (#2193)
1 parent a2a983b

2 files changed (+13, -22 lines)
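In short, the commit migrates from the decorator-based `torch._custom_op.impl` interface to the `torch._custom_ops` module (both private PyTorch APIs): the qualified name and schema are now passed as plain positional strings, and the decorated placeholder function is no longer needed. A minimal before/after sketch with a hypothetical one-tensor op, assuming a PyTorch build where both modules exist:

```python
# Before (torch._custom_op.impl): a decorated placeholder function carries the
# schema. Op names here are hypothetical, for illustration only.
from torch._custom_op.impl import custom_op

@custom_op(
    qualname="tensorrt::my_old_op",
    manual_schema="(Tensor x) -> Tensor",
)
def my_old_op(x):  # type: ignore[no-untyped-def]
    # Placeholder: defines schema, name, and namespace only.
    ...

# After (torch._custom_ops): one call registers name and schema directly; no
# placeholder function is needed.
import torch._custom_ops as library

library.custom_op("tensorrt::my_new_op", "(Tensor x) -> Tensor")
```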


py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py

Lines changed: 6 additions & 11 deletions
@@ -1,26 +1,21 @@
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-
-@custom_op(
-    qualname="tensorrt::einsum",
-    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+library.custom_op(
+    "tensorrt::einsum",
+    "(str equation, Tensor[] tensors) -> Tensor",
 )
-def einsum(equation, tensors):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
 
 
-@einsum.impl("cpu")  # type: ignore[misc]
-@einsum.impl("cuda")  # type: ignore[misc]
-@einsum.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::einsum")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
 def einsum_generic(
     *args: Any,
     **kwargs: Any,
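For context, here is a self-contained sketch of what the updated registration pattern in einsum.py looks like after this change. The body of `einsum_generic` is not shown in this diff, so the `torch.einsum` dispatch below is an assumption based on the declared schema:

```python
from typing import Any

import torch
import torch._custom_ops as library

# Register the schema; the op becomes callable as torch.ops.tensorrt.einsum.
library.custom_op(
    "tensorrt::einsum",
    "(str equation, Tensor[] tensors) -> Tensor",
)

# One function backs both the real (CPU/CUDA by default) implementation and
# the abstract implementation used during fake-tensor tracing.
@library.impl("tensorrt::einsum")
@library.impl_abstract("tensorrt::einsum")
def einsum_generic(*args: Any, **kwargs: Any) -> Any:
    # Assumed dispatch: forward to stock torch.einsum per the schema,
    # where args[0] is the equation and args[1] is the tensor list.
    return torch.einsum(args[0], *args[1])

# Usage: the registered op is reachable through the torch.ops namespace.
out = torch.ops.tensorrt.einsum("ij,jk->ik", [torch.ones(2, 3), torch.ones(3, 4)])
print(out.shape)  # torch.Size([2, 4])
```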

py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py

Lines changed: 7 additions & 11 deletions
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
@@ -20,13 +20,10 @@
 # types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
 # Then, create a placeholder function with no operations, but having the same schema and naming as that
 # used in the decorator
-@custom_op(
-    qualname="tensorrt::maxpool1d",
-    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
+library.custom_op(
+    "tensorrt::maxpool1d",
+    "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
 )
-def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
 
 
 # 2. The Generic Implementation
@@ -36,9 +33,8 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ig
 # is desirable. If the operator to replace is a custom module you've written, then add its Torch
 # implementation here. Note that the function header to the generic function can have specific arguments
 # as in the above placeholder
-@maxpool1d.impl("cpu")  # type: ignore[misc]
-@maxpool1d.impl("cuda")  # type: ignore[misc]
-@maxpool1d.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::maxpool1d")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::maxpool1d")  # type: ignore[misc]
 def maxpool1d_generic(
     *args: Any,
     **kwargs: Any,
@@ -69,7 +65,7 @@ def maxpool1d_generic(
 # "bias": bias,
 # ...
 #
-@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+@register_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)  # type: ignore
 def maxpool1d_insertion_fn(
     gm: torch.fx.GraphModule,
     node: torch.fx.Node,
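Once registered, the op is invoked through the same `torch.ops.tensorrt.maxpool1d` handle that `register_substitution` targets above. A hedged usage sketch, assuming the torch_tensorrt substitution module has been imported so the registrations have run, and assuming `maxpool1d_generic` (whose body this diff does not show) forwards to the stock 1D max pool:

```python
import torch

# Input shaped (N, C, L); kernel_size/stride/padding/dilation are int[1]
# lists per the schema registered above.
x = torch.randn(1, 3, 32)
out = torch.ops.tensorrt.maxpool1d(x, [2], [2], [0], [1], False)
print(out.shape)  # torch.Size([1, 3, 16]) if the generic impl mirrors nn.MaxPool1d
```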

0 commit comments
