Commit e023d8f

salilsdesai authored and facebook-github-bot committed
Add pattern + replacement for Embedding with padding_idx
Summary:
This diff adds a pattern/replacement pair for embedding with padding_idx, so that the embedding op in the NLU model is quantized successfully. Previously that op was not being quantized: embedding in the NLU model passes an extra arg, padding_idx, which the pattern used to match embedding ops for replacement in model graphs did not expect. The change also reduces the size of the NLU model from 11.4 MB to 4.4 MB, because embedding weight tensors are now stored in quantized form instead of fp32.

Reviewed By: digantdesai, mcr229

Differential Revision: D48191947

fbshipit-source-id: 47283aa8c4990325238c362d130d7e2d141fcf0f
1 parent 0b317e5 commit e023d8f
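
The summary attributes the missed quantization to an arity mismatch: the existing pattern was traced from a two-argument embedding call, while the model's node carries a third argument, padding_idx. The toy matcher below illustrates that idea only. `Node` and `matches` are hypothetical stand-ins invented for this sketch and are not how the real graph matcher is implemented.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a graph node: an op name plus positional args."""
    op: str
    args: Tuple[str, ...]


def matches(pattern: Node, node: Node) -> bool:
    """Toy structural match: same op and the same number of positional args."""
    return pattern.op == node.op and len(pattern.args) == len(node.args)


# Shape of the pre-existing pattern: embedding(weight, indicies)
pattern = Node("aten.embedding.default", ("weight", "indicies"))

# Shape of the pattern added by this commit: embedding(weight, indicies, padding_idx)
pattern_with_padding_idx = Node(
    "aten.embedding.default", ("weight", "indicies", "padding_idx")
)

# A node from a model that passes padding_idx explicitly
model_node = Node("aten.embedding.default", ("weight", "indicies", "padding_idx"))

# Only the three-argument pattern matches the three-argument node
matched = [
    name
    for name, p in [
        ("pattern", pattern),
        ("pattern_with_padding_idx", pattern_with_padding_idx),
    ]
    if matches(p, model_node)
]
print(matched)
```

Under this simplified model, registering both patterns lets two-argument and three-argument embedding calls each find a match, which is why the commit appends a second tuple instead of editing the first.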

1 file changed, 47 additions and 1 deletion

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 47 additions & 1 deletion
@@ -478,12 +478,58 @@ def replacement(
         )
         return out
 
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_with_padding_idx(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indicies,
+        padding_idx,
+    ):
+        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            0,
+            weight_quant_min,
+            weight_quant_max,
+            torch.uint8,
+        )
+        out = torch.ops.aten.embedding.default(weight, indicies, padding_idx)
+        return out
+
+    def replacement_with_padding_idx(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indicies,
+        _,  # padding_idx only matters for training and not when running op for inference
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            weight_quant_min,
+            weight_quant_max,
+            indicies,
+        )
+        return out
+
     return [
         (
             _trace_and_lower_to_edge_ops(pattern),
             _trace_and_lower_to_edge_ops(replacement),
             [],
-        )
+        ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
+            _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
+            [],
+        ),
     ]
 
 patterns_and_replacements = []
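
The replacement is only valid if dequantizing the whole weight table and then looking up rows gives the same result as a fused lookup that dequantizes just the selected rows, with padding_idx ignored. The pure-Python sketch below checks that equivalence on a tiny table. The three functions are simplified stand-ins written for this illustration (they omit the quant_min/quant_max arguments and operate on nested lists), not the real `quantized_decomposed` or `aten` ops, and the sample values are made up.

```python
from typing import List, Optional


def dequantize_per_channel(
    weight_q: List[List[int]], scales: List[float], zero_points: List[int]
) -> List[List[float]]:
    """Dequantize along axis 0: each row uses its own scale and zero point."""
    return [
        [(q - zp) * s for q in row]
        for row, s, zp in zip(weight_q, scales, zero_points)
    ]


def embedding(
    weight: List[List[float]], indices: List[int], padding_idx: Optional[int] = None
) -> List[List[float]]:
    """Forward lookup. padding_idx is accepted but unused, mirroring the
    commit's comment that it only matters for training."""
    return [list(weight[i]) for i in indices]


def embedding_byte(
    weight_q: List[List[int]],
    scales: List[float],
    zero_points: List[int],
    indices: List[int],
) -> List[List[float]]:
    """Fused lookup: gather quantized rows and dequantize only those rows."""
    return [
        [(q - zero_points[i]) * scales[i] for q in weight_q[i]] for i in indices
    ]


# Made-up uint8 weights with per-row quantization parameters
weight_q = [[0, 128, 255], [10, 20, 30], [200, 100, 0]]
scales = [0.5, 0.25, 2.0]
zero_points = [128, 0, 100]
indices = [2, 0, 2]

# Pattern side: dequantize everything, then look up (padding_idx passed, ignored)
reference = embedding(
    dequantize_per_channel(weight_q, scales, zero_points), indices, padding_idx=0
)
# Replacement side: fused quantized lookup without padding_idx
fused = embedding_byte(weight_q, scales, zero_points, indices)

print(reference == fused)
```

Keeping the table in its quantized form and dequantizing only the gathered rows is also what makes the stored weights smaller than their fp32 counterparts, consistent with the size reduction reported in the summary.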
