fix: Update aten.embedding to reflect schema #2182

Merged 1 commit on Aug 17, 2023
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (27 changes: 6 additions & 21 deletions)

@@ -101,23 +101,9 @@ def aten_ops_div(
         )


-def embedding_param_validator(embedding_node: Node) -> bool:
-    max_norm = args_bounds_check(embedding_node.args, 2)
-    norm_type = args_bounds_check(embedding_node.args, 3)
-    scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
-    sparse = args_bounds_check(embedding_node.args, 5)
-
-    if max_norm is not None:
-        _LOGGER.debug(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-        return False
-
-    if norm_type is not None and norm_type != 2.0:
-        _LOGGER.debug(
-            f"Currently we don't support specifying norm_type, got {norm_type}."
-        )
-        return False
+def embedding_param_validator(embedding_node: Node):
+    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
+    sparse = args_bounds_check(embedding_node.args, 4)

     if scale_grad_by_freq is not None:
         _LOGGER.debug(
@@ -149,10 +135,9 @@ def aten_ops_embedding(
         name,
         input=args[1],
         weight=args[0],
-        max_norm=args_bounds_check(args, 2),
-        norm_type=args_bounds_check(args, 3),
-        scale_grad_by_freq=args_bounds_check(args, 4),
-        sparse=args_bounds_check(args, 5),
+        # args[2] is the padding index, which is useful for training only
+        scale_grad_by_freq=args_bounds_check(args, 3),
+        sparse=args_bounds_check(args, 4),
     )
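
For reference, the updated argument indices follow the `aten::embedding` schema, `embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)`; `max_norm` and `norm_type` belong to `torch.nn.functional.embedding`, which applies them before this op is ever emitted, so they never appear in the node's args. A minimal sketch of the mapping (tensor shapes and values are illustrative only):

```python
import torch

# aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1,
#                 bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
weight = torch.randn(10, 4)        # args[0]: the embedding matrix
indices = torch.tensor([1, 3, 5])  # args[1]: the lookup indices

# Positional layout the converter now assumes:
#   args[2] = padding_idx (training-only, ignored by the converter)
#   args[3] = scale_grad_by_freq, args[4] = sparse
out = torch.ops.aten.embedding.default(weight, indices, -1, False, False)
assert torch.equal(out, weight[indices])  # the forward pass is a plain row lookup
```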
py/torch_tensorrt/dynamo/conversion/impl/embedding.py (19 changes: 1 addition & 18 deletions)

@@ -14,16 +14,9 @@ def embedding(
     name: str,
     input: TRTTensor,
     weight: TRTTensor,
-    max_norm: None,
-    norm_type: None,
     scale_grad_by_freq: bool,
     sparse: bool,
 ) -> TRTTensor:
-    if network.has_implicit_batch_dimension:
-        raise RuntimeError(
-            "The `embedding` function should be called with explicit batch dimension."
-        )
-
     indices_tensor = input
     embedding_tensor = weight
     if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
@@ -37,16 +30,6 @@
     # unsupported parameters
     # ignore padding_idx since it is meaningful for training only

-    if max_norm is not None:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {max_norm}."
-        )
-
-    if norm_type is not None and norm_type != 2.0:
-        raise RuntimeError(
-            f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
-        )
-
     if scale_grad_by_freq:
         raise RuntimeError(
             "Currently we don't support scale gradient by word frequency."
@@ -57,5 +40,5 @@

     # Implement embedding lookup with gather layer
     gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
-    set_layer_name(gather_layer, target, name + "_gather")
+    set_layer_name(gather_layer, target, name + "_gather", source_ir)
     return gather_layer.get_output(0)
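
The lookup itself is just a gather along axis 0 of the weight tensor, which is why a single `add_gather` layer (now tagged with `source_ir` via `set_layer_name`) is all the converter needs once the training-only parameters are rejected or ignored. A small PyTorch-only reference, not the TensorRT path, showing the equivalence the converted layer is expected to reproduce; shapes are illustrative:

```python
import torch

weight = torch.randn(6, 3)                # embedding table: one row per token id
indices = torch.tensor([[0, 2], [5, 1]])  # integer indices of arbitrary shape

# Selecting rows of `weight` along axis 0 is the same operation the
# TensorRT gather layer performs in the converted engine.
gathered = weight[indices]                # shape: (2, 2, 3)
reference = torch.ops.aten.embedding.default(weight, indices)
assert torch.equal(gathered, reference)
```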