1
1
from typing import Any , Dict , Optional , Tuple
2
2
3
3
import torch
4
- from torch ._custom_op . impl import custom_op
4
+ import torch ._custom_ops as library
5
5
from torch .fx .node import Argument , Target
6
6
from torch_tensorrt .dynamo .lowering ._pre_aot_lowering import register_substitution
7
7
from torch_tensorrt .fx .converter_registry import tensorrt_converter
20
20
# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
21
21
# Then, create a placeholder function with no operations, but having the same schema and naming as that
22
22
# used in the decorator
23
- @ custom_op (
24
- qualname = "tensorrt::maxpool1d" ,
25
- manual_schema = "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor" ,
23
+ library . custom_op (
24
+ "tensorrt::maxpool1d" ,
25
+ "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor" ,
26
26
)
27
- def maxpool1d (x , kernel_size , stride , padding , dilation , ceil_mode ): # type: ignore[no-untyped-def]
28
- # Defines operator schema, name, namespace, and function header
29
- ...
30
27
31
28
32
29
# 2. The Generic Implementation
@@ -36,9 +33,8 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode): # type: ig
36
33
# is desirable. If the operator to replace is a custom module you've written, then add its Torch
37
34
# implementation here. Note that the function header to the generic function can have specific arguments
38
35
# as in the above placeholder
39
- @maxpool1d .impl ("cpu" ) # type: ignore[misc]
40
- @maxpool1d .impl ("cuda" ) # type: ignore[misc]
41
- @maxpool1d .impl_abstract () # type: ignore[misc]
36
+ @library .impl ("tensorrt::maxpool1d" ) # type: ignore[misc]
37
+ @library .impl_abstract ("tensorrt::maxpool1d" ) # type: ignore[misc]
42
38
def maxpool1d_generic (
43
39
* args : Any ,
44
40
** kwargs : Any ,
@@ -69,7 +65,7 @@ def maxpool1d_generic(
69
65
# "bias": bias,
70
66
# ...
71
67
#
72
- @register_substitution (torch .nn .MaxPool1d , torch .ops .tensorrt .maxpool1d )
68
+ @register_substitution (torch .nn .MaxPool1d , torch .ops .tensorrt .maxpool1d ) # type: ignore
73
69
def maxpool1d_insertion_fn (
74
70
gm : torch .fx .GraphModule ,
75
71
node : torch .fx .Node ,
0 commit comments