```diff
@@ -13,8 +13,7 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_plugin,
-    get_trt_tensor,
+    get_positive_dim,
     has_dynamic_shape,
     set_layer_name,
 )
```
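The import change drops the two plugin helpers (`get_trt_plugin`, `get_trt_tensor`) that only the deleted plugin paths below used, and brings in `get_positive_dim` for axis handling elsewhere in the file. As a rough sketch of that utility's behavior (an assumption based on the fx converter utilities, not the actual source), it wraps negative axis indices the way Python indexing does:

```python
# Hypothetical re-implementation of get_positive_dim, for illustration only;
# the real helper lives in torch_tensorrt.fx.converters.converter_utils.
def get_positive_dim(dim: int, dim_size: int) -> int:
    # Wrap a negative axis index into the range [0, dim_size).
    return dim % dim_size if dim < 0 else dim

assert get_positive_dim(-1, 4) == 3  # last axis of a 4-D tensor
assert get_positive_dim(2, 4) == 2   # non-negative indices pass through
```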
```diff
@@ -53,10 +52,7 @@ def native_batch_norm(
     if running_var is None:
         running_var = 1.0
 
-    scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-        cast(torch.Tensor, to_numpy(running_var)) + eps
-    )
-
+    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
     bias = to_numpy(bias) - to_numpy(running_mean) * scale
     power = np.ones_like(scale)
 
```
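Dropping the `cast(torch.Tensor, ...)` wrappers is safe because `to_numpy` already returns `np.ndarray`, so the casts were dead weight. The folding itself is the usual eval-mode batch-norm algebra: `(x - mean) / sqrt(var + eps) * w + b` collapses into a single elementwise `scale * x + shift`, which is what the subsequent TensorRT scale layer consumes. A quick NumPy check with made-up shapes (not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
C, eps = 4, 1e-5
x = rng.standard_normal((2, C), dtype=np.float32)
weight, bias = rng.standard_normal(C), rng.standard_normal(C)
running_mean, running_var = rng.standard_normal(C), rng.random(C) + 0.5

# Folded form computed by the converter ("bias" in the diff is this shift).
scale = weight / np.sqrt(running_var + eps)
shift = bias - running_mean * scale

# Textbook eval-mode batch norm.
reference = (x - running_mean) / np.sqrt(running_var + eps) * weight + bias

assert np.allclose(scale * x + shift, reference, atol=1e-5)
```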
```diff
@@ -135,78 +131,6 @@ def layer_norm(
     eps: float,
     cudnn_enable: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    gamma = (
-        weight.detach().cpu().float().numpy()
-        if isinstance(weight, torch.Tensor)
-        else weight
-    )
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = (
-        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
-    )
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    try:
-        normalized_shape_arr = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
-        normalized_shape_arr = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if ctx.net.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm_no_plugin(
-            ctx, target, source_ir, name, input, normalized_shape, weight, bias, eps
-        )
-    layer = ctx.net.add_plugin_v2([input], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm_no_plugin(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
```
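This deletion removes both the fx2trt layer-norm plugin lookup and the separate `layer_norm_no_plugin` fallback; `layer_norm` now always lowers to native TensorRT layers (the code following the signature was previously the body of the fallback). For reference, the computation being lowered is ordinary layer norm over the trailing `normalized_shape` dimensions. A small PyTorch sketch with illustrative shapes, checked against `torch.nn.functional.layer_norm`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)
normalized_shape = (8,)  # normalize over the trailing dimension(s)
weight, bias, eps = torch.randn(8), torch.randn(8), 1e-5

# The decomposition the native converter has to express with TRT layers.
dims = tuple(range(-len(normalized_shape), 0))
mean = x.mean(dim=dims, keepdim=True)                 # E[X]
var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)  # biased Var[X]
out = (x - mean) / torch.sqrt(var + eps) * weight + bias

assert torch.allclose(
    out, F.layer_norm(x, normalized_shape, weight, bias, eps), atol=1e-6
)
```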
```diff
@@ -357,45 +281,180 @@ def group_norm(
     eps: float,
     cudnn_enabled: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
     if bias is None:
         bias = to_numpy(0.0)
 
-    scale = get_trt_tensor(network, weight, "scale")
-    bias = get_trt_tensor(network, bias, "bias")
+    assert (
+        len(input.shape) >= 3
+    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    B, C = input.shape[0], input.shape[1]
 
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    # Groups are a subdivision of the channel dimension.
+    assert (
+        C % num_groups == 0
+    ), f"The number of channels ({C}) should be divisible by num_groups ({num_groups})!"
+
+    # Normalize every group.
+    reshaped_input = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input,
+        shape=(B * num_groups, -1),
     )
-    num_groups_filed = trt.PluginField(
-        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    dim = (
+        len(reshaped_input.shape) - 1
+    )  # TODO: PR #2347 supported negative dimension in reduce, could be -1
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean",
+        reshaped_input,
+        dim=dim,
+        keepdim=True,
     )
 
-    field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        network,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
 
-    try:
-        # Here's the schema of the plugin:
-        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
-        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find group norm plugin, fall back to TensorRT implementation."
-        )
+    # variance = mean(pow(sub_trt, 2))
+    pow_layer = network.add_constant(
+        (1,) * len(sub_trt.shape),
+        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
+    )
+    pow_layer.name = f"{name}_power"
+
+    pow_var = impl.elementwise.pow(
+        network,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_layer.get_output(0),
+    )
+
+    var_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean_var",
+        pow_var,
+        dim=dim,
+        keepdim=True,
+    )
+
+    # sqrt((var + eps))
+    eps_layer = network.add_constant(
+        (1,) * len(reshaped_input.shape),
+        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
+    )
+    eps_layer.name = f"{name}_eps"
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_layer.get_output(0),
+    )
+    sqrt_trt = impl.unary.sqrt(
+        network,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
+        network,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )
+
+    # Apply per-channel scale and bias.
+    output = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_div",
+        div_trt,
+        shape=input.shape,
+    )
+
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+    reshaped_weight = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_weight",
+        weight,
+        shape=weight_bias_shape,
+    )
+
+    output = impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mul_scale",
+        output,
+        reshaped_weight,
+    )
 
-    layer = network.add_plugin_v2([input, scale, bias], plugin)
-    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+    reshaped_bias = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_bias",
+        bias,
+        shape=weight_bias_shape,
+    )
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add_bias",
+        output,
+        reshaped_bias,
+    )
+
+    # TODO: compute the last two return values
+    # const1_layer = network.add_constant(
+    #     (1,) * len(sqrt_trt.shape),
+    #     trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)),
+    # )
+    # const1_layer.name = f"{name}_const1"
+
+    # rsqrt_trt = impl.elementwise.div(
+    #     network,
+    #     target,
+    #     source_ir,
+    #     f"{name}_rsqrt",
+    #     const1_layer.get_output(0),
+    #     sqrt_trt,
+    # )
 
-    # PyTorch requires three return values: (out, mean, rstd)
-    dummy_tensor = torch.tensor(0)
-    return layer.get_output(0), dummy_tensor, dummy_tensor
+    return add_trt, torch.tensor(0), torch.tensor(0)
 
 
 def softmax(
```
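The replacement lowers group norm to plain TensorRT layers instead of the `GroupNormalizationPlugin`: reshape to `(B * num_groups, -1)` so each row holds one group, reduce to get E[X] and the biased variance, normalize, reshape back, then broadcast the per-channel `weight`/`bias` through a `(1, C, 1, ..., 1)` view. The commented-out block keeps a TODO for the real `mean`/`rstd` outputs, which are still returned as placeholder tensors. The same pipeline in a few lines of PyTorch (shapes are illustrative), validated against `torch.nn.functional.group_norm`:

```python
import torch
import torch.nn.functional as F

B, C, H, W, num_groups, eps = 2, 6, 4, 4, 3, 1e-5
x = torch.randn(B, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

# One row per (batch sample, group), mirroring the converter's first reshape.
g = x.reshape(B * num_groups, -1)
mean = g.mean(dim=-1, keepdim=True)                 # E[X]
var = ((g - mean) ** 2).mean(dim=-1, keepdim=True)  # biased Var[X]
normed = ((g - mean) / torch.sqrt(var + eps)).reshape(B, C, H, W)

# Per-channel affine via a broadcastable (1, C, 1, 1) view.
out = normed * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

assert torch.allclose(out, F.group_norm(x, num_groups, weight, bias, eps), atol=1e-5)
```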