Commit 06e544e

fix: Update aten.embedding to reflect schema (#2182)
1 parent 63005e0
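
For context (not part of the diff): the aten.embedding operator, unlike torch.nn.functional.embedding, takes no max_norm or norm_type arguments. Its schema is approximately

    embedding(Tensor weight, Tensor indices, int padding_idx=-1,
              bool scale_grad_by_freq=False, bool sparse=False) -> Tensor

so args[0] is the weight table, args[1] the indices, args[2] the padding index (meaningful for training only), args[3] scale_grad_by_freq, and args[4] sparse. The changes below remove the max_norm/norm_type handling and shift the remaining argument indices to match.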

File tree: 2 files changed, +7 -39 lines changed

2 files changed

+7
-39
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 6 additions & 21 deletions
@@ -101,23 +101,9 @@ def aten_ops_div(
     )
 
 
-def embedding_param_validator(embedding_node: Node) -> bool:
-    max_norm = args_bounds_check(embedding_node.args, 2)
-    norm_type = args_bounds_check(embedding_node.args, 3)
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
-    sparse = args_bounds_check(embedding_node.args, 5)
-
-    if max_norm is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-        return False
-
-    if norm_type is not None and norm_type != 2.0:
-        _LOGGER.debug(
-            f"Currently we don't support specifying norm_type, got {norm_type}."
-        )
-        return False
+def embedding_param_validator(embedding_node: Node):
+    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
+    sparse = args_bounds_check(embedding_node.args, 4)
 
     if scale_grad_by_freq is not None:
         _LOGGER.debug(
@@ -149,10 +135,9 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        max_norm=args_bounds_check(args, 2),
-        norm_type=args_bounds_check(args, 3),
-        scale_grad_by_freq=args_bounds_check(args, 4),
-        sparse=args_bounds_check(args, 5),
+        # args[2] is the padding index, which is useful for training only
+        scale_grad_by_freq=args_bounds_check(args, 3),
+        sparse=args_bounds_check(args, 4),
     )
 
 
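Both hunks above read optional trailing arguments through args_bounds_check, which tolerates nodes that omit them. A minimal sketch of that behaviour, assuming the real helper in torch_tensorrt's dynamo conversion utilities follows the same pattern (the name and call shape are taken from the diff above; the body is illustrative only):

    from typing import Any, Optional, Sequence

    def args_bounds_check(args: Sequence[Any], i: int, replacement: Optional[Any] = None) -> Any:
        # Return args[i] when the FX node actually carries that positional argument,
        # otherwise fall back to a replacement (None by default) instead of raising IndexError.
        return args[i] if len(args) > i else replacement

With that behaviour, embedding_param_validator sees None for scale_grad_by_freq and sparse when the traced call leaves them at their defaults.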

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 1 addition & 18 deletions
@@ -14,16 +14,9 @@ def embedding(
     name: str,
     input: TRTTensor,
     weight: TRTTensor,
-    max_norm: None,
-    norm_type: None,
     scale_grad_by_freq: bool,
     sparse: bool,
 ) -> TRTTensor:
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `embedding` function should be called with explicit batch dimension."
-        )
-
     indices_tensor = input
     embedding_tensor = weight
     if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
@@ -37,16 +30,6 @@ def embedding(
     # unsupported parameters
     # ignore padding_idx since it is meaningful for training only
 
-    if max_norm is not None:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-
-    if norm_type is not None and norm_type != 2.0:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
-        )
-
     if scale_grad_by_freq:
         raise RuntimeError(
             "Currently we don't support scale gradient by word frequency."
@@ -57,5 +40,5 @@ def embedding(
 
     # Implement embedding lookup with gather layer
     gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
-    set_layer_name(gather_layer, target, name + "_gather")
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
     return gather_layer.get_output(0)
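
The gather-based lowering above works because, at inference time, aten.embedding is just a row lookup into the weight table. A quick way to sanity-check that equivalence in plain PyTorch, independent of TensorRT (illustrative values only):

    import torch

    weight = torch.randn(10, 4)                # embedding table: 10 rows of dim 4
    indices = torch.tensor([[1, 3], [7, 0]])   # arbitrary int64 indices

    ref = torch.ops.aten.embedding.default(weight, indices)  # the op this converter handles
    gathered = weight[indices]                 # gather along axis 0, as the TRT gather layer does
    assert torch.equal(ref, gathered)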
