```diff
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -105,102 +106,30 @@ def layer_norm(
     cudnn_enable: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    shape = weight.shape
-    gamma = to_numpy(weight).reshape(shape)
-    beta = to_numpy(bias).reshape(shape)
-
-    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
-
-    # E[x]
-    mean_expected_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean_expected", input, dims, True
-    )
-
-    # X-E[x]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        input,
-        mean_expected_trt,
-    )
-
-    # Variance = mean(pow(x_sub_mean, 2))
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow_var",
-        sub_trt,
-        pow_trt,
-    )
-    mean_trt = impl.reduce.mean(
-        ctx, target, source_ir, f"{name}_mean", pow_var, dims, True
-    )
-
-    # sqrt((var + eps))
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        mean_trt,
-        eps_trt,
-    )
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
-
-    # (X - E[X]) / sqrt((var + eps))
-    div_trt = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
-
-    gamma_trt = get_trt_tensor(ctx, weight, f"{name}_gamma")
-    beta_trt = get_trt_tensor(ctx, bias, f"{name}_beta")
-
-    # y * gamma + beta
-    scaled_y = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        div_trt,
-        gamma_trt,
-    )
+    dims = list(range(len(input.shape) - len(normalized_shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)
+
+    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+    if tuple(input.shape) != tuple(weight.shape):
+        weight = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape
+        )
+    if tuple(input.shape) != tuple(bias.shape):
+        bias = impl.slice.expand(
+            ctx, target, source_ir, f"{name}_expand_bias", bias, input.shape
+        )
 
-    output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        scaled_y,
-        beta_trt,
-    )
+    layer_norm = ctx.net.add_normalization(input, weight, bias, axes)
+    layer_norm.epsilon = eps
+    layer_norm.compute_precision = input.dtype
+    set_layer_name(layer_norm, target, f"{name}_layer_norm", source_ir)
 
     if return_mean_rstd:
         # return fake mean and rstd for now
-        return output, None, None
+        return layer_norm.get_output(0), None, None
 
-    return output
+    return layer_norm.get_output(0)
 
 
 def native_group_norm(
```
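For context on what this commit replaces: the deleted converter body assembled layer normalization out of individual reduce and elementwise TensorRT layers, following the textbook formula

$$
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

where the mean and variance are taken over the trailing `normalized_shape` dimensions. The new body instead emits a single `INormalizationLayer` via `ctx.net.add_normalization`, which computes the same expression as one fused layer, with `epsilon` and `compute_precision` set directly on the layer.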
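One subtlety: `add_normalization` takes its reduction axes as a bitmask rather than a list of dimension indices, which is why the new code routes `dims` through `get_axes_for_reduce_op`. A minimal sketch of that conversion, assuming the helper computes the usual TensorRT axes mask (the standalone `axes_bitmask` name here is hypothetical, for illustration only):

```python
def axes_bitmask(dims: list[int]) -> int:
    """Convert dimension indices to a TensorRT axes bitmask:
    bit d is set when dimension d participates in the reduction."""
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask

# Normalizing over the last dimension of a (N, S, H) tensor: dims = [2].
assert axes_bitmask([2]) == 0b100
# layer_norm whose normalized_shape covers the last two dims of a 4-D input.
assert axes_bitmask([2, 3]) == 0b1100
```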
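As a sanity check of the semantics both the old and new paths target, the removed step-by-step decomposition agrees with `torch.nn.functional.layer_norm`; here is a small self-contained reproduction in plain PyTorch (shapes and tolerances are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)
weight, bias, eps = torch.randn(4), torch.randn(4), 1e-5

# Reference: PyTorch's built-in layer norm over the last dimension.
expected = F.layer_norm(x, (4,), weight, bias, eps)

# The removed converter body, step by step: E[x], x - E[x],
# Var[x] = mean((x - E[x])**2), then scale by gamma and shift by beta.
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
actual = (x - mean) / torch.sqrt(var + eps) * weight + bias

assert torch.allclose(expected, actual, atol=1e-5)
```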