@@ -10,8 +10,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    get_trt_plugin,
-    get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
     to_numpy,
@@ -58,10 +56,7 @@ def batch_norm(
     if running_var is None:
         running_var = 1.0

-    scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-        cast(torch.Tensor, to_numpy(running_var)) + eps
-    )
-
+    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
     bias = to_numpy(bias) - to_numpy(running_mean) * scale
     power = np.ones_like(scale)

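For reference, the simplified line above folds the frozen batch-norm statistics into a single per-channel affine transform, y = scale * x + bias. A minimal NumPy sketch of that folding (illustrative only, not part of the diff):

    import numpy as np

    def fold_batch_norm(weight, bias, running_mean, running_var, eps):
        # gamma / sqrt(var + eps) becomes the per-channel scale...
        scale = weight / np.sqrt(running_var + eps)
        # ...and beta - mean * scale becomes the per-channel bias.
        return scale, bias - running_mean * scale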
@@ -107,78 +102,6 @@ def layer_norm(
     eps: float,
     cudnn_enable: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    gamma = (
-        weight.detach().cpu().float().numpy()
-        if isinstance(weight, torch.Tensor)
-        else weight
-    )
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = (
-        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
-    )
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    try:
-        normalized_shape_arr = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
-        normalized_shape_arr = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if network.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm_no_plugin(
-            network, target, source_ir, name, input, normalized_shape, weight, bias, eps
-        )
-    layer = network.add_plugin_v2([input], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm_no_plugin(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)

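With the plugin path gone, `layer_norm` now always lowers through the native implementation that previously lived in `layer_norm_no_plugin`. For orientation, the math that path implements, sketched in NumPy under the assumption that normalization runs over the trailing `normalized_shape` axes (names here are illustrative, not from the PR):

    import numpy as np

    def layer_norm_ref(x, normalized_shape, weight, bias, eps):
        # Reduce over the trailing axes covered by normalized_shape.
        axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
        # Normalize, then apply the learned affine parameters.
        return (x - mean) / np.sqrt(var + eps) * weight + bias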
@@ -333,45 +256,180 @@ def group_norm(
     eps: float,
     cudnn_enabled: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)

     if bias is None:
         bias = to_numpy(0.0)

-    scale = get_trt_tensor(network, weight, "scale")
-    bias = get_trt_tensor(network, bias, "bias")
+    assert (
+        len(input.shape) >= 3
+    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    B, C = input.shape[0], input.shape[1]

-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    # Groups are a subdivision of the channel dimension.
+    assert (
+        C % num_groups == 0
+    ), f"The num of channels ({C}) should be divisible by num_groups ({num_groups})!"
+
+    # Normalize every group.
+    reshaped_input = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input,
+        shape=(B * num_groups, -1),
     )
-    num_groups_filed = trt.PluginField(
-        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    dim = (
+        len(reshaped_input.shape) - 1
+    )  # TODO: PR #2347 supported negative dimension in reduce, could be -1
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean",
+        reshaped_input,
+        dim=dim,
+        keepdim=True,
     )

-    field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        network,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )

-    try:
-        # Here's the schema of the plugin:
-        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
-        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find group norm plugin, fall back to TensorRT implementation."
-        )
+    # variance = mean(pow(sub_trt, 2))
+    pow_layer = network.add_constant(
+        (1,) * len(sub_trt.shape),
+        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
+    )
+    pow_layer.name = f"{name}_power"

-    layer = network.add_plugin_v2([input, scale, bias], plugin)
-    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+    pow_var = impl.elementwise.pow(
+        network,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_layer.get_output(0),
+    )
+
+    var_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean_var",
+        pow_var,
+        dim=dim,
+        keepdim=True,
+    )
+
+    # sqrt((var + eps))
+    eps_layer = network.add_constant(
+        (1,) * len(reshaped_input.shape),
+        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
+    )
+    eps_layer.name = f"{name}_eps"
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_layer.get_output(0),
+    )
+    sqrt_trt = impl.unary.sqrt(
+        network,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
+        network,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )
+
+    # Apply per-channel scale and bias.
+    output = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_div",
+        div_trt,
+        shape=input.shape,
+    )
+
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+    reshaped_weight = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_weight",
+        weight,
+        shape=weight_bias_shape,
+    )
+
+    output = impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mul_scale",
+        output,
+        reshaped_weight,
+    )
+
+    reshaped_bias = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_bias",
+        bias,
+        shape=weight_bias_shape,
+    )
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add_bias",
+        output,
+        reshaped_bias,
+    )

-    # PyTorch requires three return values: (out, mean, rstd)
-    dummy_tensor = torch.tensor(0)
-    return layer.get_output(0), dummy_tensor, dummy_tensor
+    # TODO: compute the last two return values
+    # const1_layer = network.add_constant(
+    #     (1,) * len(sqrt_trt.shape),
+    #     trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)),
+    # )
+    # const1_layer.name = f"{name}_const1"
+
+    # rsqrt_trt = impl.elementwise.div(
+    #     network,
+    #     target,
+    #     source_ir,
+    #     f"{name}_rsqrt",
+    #     const1_layer.get_output(0),
+    #     sqrt_trt,
+    # )
+
+    return add_trt, torch.tensor(0), torch.tensor(0)


 def softmax(
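The new lowering mirrors the standard group-norm decomposition: flatten to (B * num_groups, -1), normalize each row with the biased variance, reshape back, then apply the per-channel affine. A quick PyTorch sanity check of that decomposition against `torch.nn.functional.group_norm` (a sketch with made-up shapes, not part of the PR):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 6, 4, 4)  # (B, C, H, W)
    num_groups, eps = 3, 1e-5
    weight, bias = torch.randn(6), torch.randn(6)

    B, C = x.shape[0], x.shape[1]
    g = x.reshape(B * num_groups, -1)  # one row per (sample, group)
    mean = g.mean(dim=-1, keepdim=True)
    var = ((g - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
    y = ((g - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    shape = (1, C) + (1,) * (x.dim() - 2)  # matches weight_bias_shape above
    y = y * weight.reshape(shape) + bias.reshape(shape)

    assert torch.allclose(y, F.group_norm(x, num_groups, weight, bias, eps), atol=1e-5)

As in the converter, only the first of PyTorch's three return values (out, mean, rstd) is computed here; the converter stubs the other two with `torch.tensor(0)` for now.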