Commit 670ff84

fix: Update embedding to reflect ATen schema
- Remove arguments not present in initial schema for embedding
- Improve coverage of embedding operator by expanding set of convertible implementations
- Update parameter-checking function accordingly
1 parent 8c62fca commit 670ff84
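
Note: the arguments removed here (max_norm, norm_type) are handled at the torch.nn.functional.embedding level rather than being arguments of the aten::embedding op itself, which is why the converter drops them. A quick way to confirm the schema the converter is matched against (output shown approximately; exact argument types vary by PyTorch version):

    import torch

    # Schema of the op this converter targets
    print(torch.ops.aten.embedding.default._schema)
    # embedding(Tensor weight, Tensor indices, int padding_idx=-1,
    #           bool scale_grad_by_freq=False, bool sparse=False) -> Tensor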

2 files changed (+7, -45 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 21 deletions
@@ -96,23 +96,8 @@ def aten_ops_div(
 
 
 def embedding_param_validator(embedding_node: Node):
-
-    max_norm = args_bounds_check(embedding_node.args, 2)
-    norm_type = args_bounds_check(embedding_node.args, 3)
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
-    sparse = args_bounds_check(embedding_node.args, 5)
-
-    if max_norm is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-        return False
-
-    if norm_type is not None and norm_type != 2.0:
-        _LOGGER.debug(
-            f"Currently we don't support specifying norm_type, got {norm_type}."
-        )
-        return False
+    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
+    sparse = args_bounds_check(embedding_node.args, 4)
 
     if scale_grad_by_freq is not None:
         _LOGGER.debug(
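
The validator reads optional positional arguments through args_bounds_check, which returns None when an argument was not passed. As a rough sketch (the helper's real location and default-argument handling in the repo are assumptions here):

    from typing import Any, Optional, Sequence

    def args_bounds_check(
        args: Sequence[Any], i: int, replacement: Optional[Any] = None
    ) -> Any:
        # Return the i-th positional argument if the node provided it,
        # otherwise fall back to a default (None here).
        return args[i] if len(args) > i else replacement

For a node recorded as embedding(weight, indices, -1, False), args_bounds_check(args, 3) returns False (scale_grad_by_freq) and args_bounds_check(args, 4) returns None (sparse was not supplied).
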
@@ -144,10 +129,9 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        max_norm=args_bounds_check(args, 2),
-        norm_type=args_bounds_check(args, 3),
-        scale_grad_by_freq=args_bounds_check(args, 4),
-        sparse=args_bounds_check(args, 5),
+        # args[2] is the padding index, which is useful for training only
+        scale_grad_by_freq=args_bounds_check(args, 3),
+        sparse=args_bounds_check(args, 4),
     )
 
 
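Since the converter now indexes the node's positional arguments directly against the ATen schema, the layout it assumes can be checked with a plain eager call (illustration only, not part of the change):

    import torch

    weight = torch.randn(10, 4)          # args[0]: embedding table
    indices = torch.tensor([1, 3, 5])    # args[1]: lookup indices

    # args[2]=padding_idx, args[3]=scale_grad_by_freq, args[4]=sparse
    out = torch.ops.aten.embedding.default(weight, indices, -1, False, False)
    assert out.shape == (3, 4)
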
Lines changed: 2 additions & 24 deletions
@@ -1,10 +1,5 @@
-import operator
-import warnings
-from typing import Optional, cast, Any
+from typing import Optional
 
-import numpy as np
-
-import tensorrt as trt
 import torch
 from torch.fx.node import Target
 
@@ -24,17 +19,10 @@ def embedding(
     name: str,
     input: TRTTensor,
     weight: TRTTensor,
-    max_norm: None,
-    norm_type: None,
     scale_grad_by_freq: bool,
     sparse: bool,
 ) -> TRTTensor:
 
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `embedding` function should be called with explicit batch dimension."
-        )
-
     indices_tensor = input
     embedding_tensor = weight
     if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
@@ -48,16 +36,6 @@ def embedding(
     # unsupported parameters
     # ignore padding_idx since it is meaningful for training only
 
-    if max_norm is not None:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-
-    if norm_type is not None and norm_type != 2.0:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
-        )
-
     if scale_grad_by_freq:
         raise RuntimeError(
             "Currently we don't support scale gradient by word frequency."
@@ -68,5 +46,5 @@ def embedding(
 
     # Implement embedding lookup with gather layer
     gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
-    set_layer_name(gather_layer, target, name + "_gather")
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
     return gather_layer.get_output(0)
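
For intuition, the gather layer used above is the TensorRT counterpart of an embedding lookup in eager PyTorch: selecting rows of the weight matrix along dimension 0. A small illustrative check (not part of the diff):

    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 4)        # embedding table
    indices = torch.tensor([1, 3, 5])  # lookup indices

    # An embedding lookup is a row gather along dim 0
    assert torch.equal(F.embedding(indices, weight), weight[indices])
    assert torch.equal(weight[indices], torch.index_select(weight, 0, indices))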
