1
+ from typing import Callable , Dict , Any
1
2
import torch
2
- from torch ._decomp import register_decomposition , core_aten_decompositions
3
+ from torch ._decomp import register_decomposition , core_aten_decompositions , OpOverload
3
4
4
5
5
- DECOMPOSITIONS = {** core_aten_decompositions ()}
6
+ DECOMPOSITIONS : Dict [ OpOverload , Callable [[ Any ], Any ]] = {** core_aten_decompositions ()}
6
7
7
8
aten = torch .ops .aten
8
9
9
-
10
- def replace_inplace_op (aten_op , outplace_op ):
10
+ def replace_inplace_op (aten_op : OpOverload , outplace_op : OpOverload ) -> Any :
11
11
"""Replace inplace operation with functional equivalent
12
12
Adapted from:
13
13
https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
14
14
"""
15
15
16
- @register_decomposition (aten_op , registry = DECOMPOSITIONS )
17
- def inplace_op (* args , ** kwargs ) :
16
+ @register_decomposition (aten_op , registry = DECOMPOSITIONS ) # type: ignore[misc]
17
+ def inplace_op (* args : Any , ** kwargs : Any ) -> Any :
18
18
out = outplace_op (* args , ** kwargs )
19
19
return args [0 ].copy_ (out )
20
20
@@ -36,29 +36,29 @@ def inplace_op(*args, **kwargs):
36
36
replace_inplace_op (aten .scatter_reduce_ , aten .scatter_reduce )
37
37
38
38
39
- @register_decomposition (aten .std , registry = DECOMPOSITIONS )
40
- def std_replacement (* args , ** kwargs ) -> torch .Tensor :
39
+ @register_decomposition (aten .std , registry = DECOMPOSITIONS ) # type: ignore[misc]
40
+ def std_replacement (* args : Any , ** kwargs : Any ) -> torch .Tensor :
41
41
return torch .sqrt (torch .var (* args , ** kwargs ))
42
42
43
43
44
- @register_decomposition (aten .rsqrt , registry = DECOMPOSITIONS )
45
- def rsqrt_replacement (* args , ** kwargs ) -> torch .Tensor :
44
+ @register_decomposition (aten .rsqrt , registry = DECOMPOSITIONS ) # type: ignore[misc]
45
+ def rsqrt_replacement (* args : Any , ** kwargs : Any ) -> torch .Tensor :
46
46
return torch .reciprocal (torch .sqrt (* args , ** kwargs ))
47
47
48
48
49
- @register_decomposition (aten .alias , registry = DECOMPOSITIONS )
49
+ @register_decomposition (aten .alias , registry = DECOMPOSITIONS ) # type: ignore[misc]
50
50
def alias_replacement (x : torch .Tensor ) -> torch .Tensor :
51
51
return x
52
52
53
53
54
- @register_decomposition (torch .ops .aten .addmm , registry = DECOMPOSITIONS )
54
+ @register_decomposition (torch .ops .aten .addmm , registry = DECOMPOSITIONS ) # type: ignore[misc]
55
55
def addmm_replacement (
56
- input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta = 1 , alpha = 1
56
+ input_ : torch .Tensor , mat1 : torch .Tensor , mat2 : torch .Tensor , * , beta : int = 1 , alpha : int = 1
57
57
) -> torch .Tensor :
58
58
return torch .add (
59
59
torch .mul (input_ , beta ), torch .mul (torch .matmul (mat1 , mat2 ), alpha )
60
60
)
61
61
62
62
63
- def get_decompositions ():
63
+ def get_decompositions () -> Dict [ OpOverload , Callable [[ Any ], Any ]] :
64
64
return DECOMPOSITIONS
0 commit comments