Skip to content

Commit eccb5b3

Browse files
PaulZhang12 authored and facebook-github-bot committed
Fix inference output dtype, use FP16 (#2537)
Summary: Pull Request resolved: #2537 This diff supports specifying the output type of the TBE during model processing in AIMP with TorchRec eager mode. Part of the ~30% QPS gain optimization for SNN on APS. Reviewed By: ZhengkaiZ Differential Revision: D65445160 fbshipit-source-id: d16226c1856486916e83192fb79730641f70fc7c
1 parent 936998d commit eccb5b3

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

torchrec/inference/modules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,6 @@ def _quantize_fp_module(
418418
model: torch.nn.Module,
419419
fp_module: FeatureProcessedEmbeddingBagCollection,
420420
fp_module_fqn: str,
421-
activation_dtype: torch.dtype = torch.float,
422421
weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE,
423422
per_fp_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None,
424423
) -> None:
@@ -428,7 +427,7 @@ def _quantize_fp_module(
428427

429428
quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
430429
fp_module.qconfig = QuantConfig(
431-
activation=quant.PlaceholderObserver.with_args(dtype=activation_dtype),
430+
activation=quant.PlaceholderObserver.with_args(dtype=output_dtype),
432431
weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
433432
per_table_weight_dtype=per_fp_table_weight_dtype,
434433
)

0 commit comments

Comments (0)