@@ -46,31 +46,22 @@ def get_ir(target: Target) -> SourceIR:
46
46
return SourceIR .UNKNOWN
47
47
48
48
49
- @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default ) # type: ignore[misc]
50
- def aten_ops_native_batch_norm (
51
- ctx : ConversionContext ,
52
- target : Target ,
53
- args : Tuple [Argument , ...],
54
- kwargs : Dict [str , Argument ],
55
- name : str ,
56
- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
57
- return impl .normalization .native_batch_norm (
58
- ctx ,
59
- target ,
60
- SourceIR .ATEN ,
61
- name ,
62
- input = args [0 ],
63
- weight = args [1 ],
64
- bias = args [2 ],
65
- running_mean = args [3 ],
66
- running_var = args [4 ],
67
- training = args [5 ],
68
- momentum = args [6 ],
69
- eps = args [7 ],
49
+ def one_user_validator (node : Node ) -> bool :
50
+ # Validate only one user, which is a getitem node that accesses the first element in the list
51
+ return (
52
+ len (node .users ) == 1
53
+ and list (node .users )[0 ].target == operator .getitem
54
+ and list (node .users )[0 ].args [1 ] == 0
70
55
)
71
56
72
57
73
- @dynamo_tensorrt_converter (torch .ops .aten .batch_norm ) # type: ignore[misc]
58
+ @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
59
+ @dynamo_tensorrt_converter (torch .ops .aten .batch_norm .default ) # type: ignore[misc]
60
+ @enforce_tensor_types (
61
+ {
62
+ 0 : (TRTTensor ,),
63
+ }
64
+ ) # type: ignore[misc]
74
65
def aten_ops_batch_norm (
75
66
ctx : ConversionContext ,
76
67
target : Target ,
@@ -91,32 +82,18 @@ def aten_ops_batch_norm(
91
82
training = args [5 ],
92
83
momentum = args [6 ],
93
84
eps = args [7 ],
94
- cudnn_enabled = args [8 ],
95
- )
96
-
97
-
98
- @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default ) # type: ignore[misc]
99
- def aten_ops_native_layer_norm (
100
- ctx : ConversionContext ,
101
- target : Target ,
102
- args : Tuple [Argument , ...],
103
- kwargs : Dict [str , Argument ],
104
- name : str ,
105
- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
106
- return impl .normalization .native_layer_norm (
107
- ctx ,
108
- target ,
109
- SourceIR .ATEN ,
110
- name ,
111
- input = args [0 ],
112
- normalized_shape = args [1 ],
113
- weight = args [2 ],
114
- bias = args [3 ],
115
- eps = args [4 ],
85
+ cudnn_enabled = args_bounds_check (args , 8 , True ),
86
+ return_mean_rstd = (target == torch .ops .aten .native_batch_norm .default ),
116
87
)
117
88
118
89
90
+ @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
119
91
@dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
92
+ @enforce_tensor_types (
93
+ {
94
+ 0 : (TRTTensor ,),
95
+ }
96
+ ) # type: ignore[misc]
120
97
def aten_ops_layer_norm (
121
98
ctx : ConversionContext ,
122
99
target : Target ,
@@ -135,10 +112,16 @@ def aten_ops_layer_norm(
135
112
bias = args_bounds_check (args , 3 ),
136
113
eps = args_bounds_check (args , 4 , 1e-05 ),
137
114
cudnn_enable = args_bounds_check (args , 5 , True ),
115
+ return_mean_rstd = (target == torch .ops .aten .native_layer_norm .default ),
138
116
)
139
117
140
118
141
- @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default ) # type: ignore[misc]
119
+ @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
120
+ @enforce_tensor_types (
121
+ {
122
+ 0 : (TRTTensor ,),
123
+ }
124
+ ) # type: ignore[misc]
142
125
def aten_ops_native_group_norm (
143
126
ctx : ConversionContext ,
144
127
target : Target ,
@@ -163,6 +146,11 @@ def aten_ops_native_group_norm(
163
146
164
147
165
148
@dynamo_tensorrt_converter (torch .ops .aten .group_norm .default ) # type: ignore[misc]
149
+ @enforce_tensor_types (
150
+ {
151
+ 0 : (TRTTensor ,),
152
+ }
153
+ ) # type: ignore[misc]
166
154
def aten_ops_group_norm (
167
155
ctx : ConversionContext ,
168
156
target : Target ,
@@ -838,15 +826,6 @@ def aten_ops_prod(
838
826
)
839
827
840
828
841
- def one_user_validator (node : Node ) -> bool :
842
- # Validate only one user, which is a getitem node that accesses the first element in the list
843
- return (
844
- len (node .users ) == 1
845
- and list (node .users )[0 ].target == operator .getitem
846
- and list (node .users )[0 ].args [1 ] == 0
847
- )
848
-
849
-
850
829
@dynamo_tensorrt_converter (torch .ops .aten .max .default ) # type: ignore[misc]
851
830
@dynamo_tensorrt_converter (torch .ops .aten .max .dim , capability_validator = one_user_validator ) # type: ignore[misc]
852
831
def aten_ops_max (
0 commit comments