Skip to content

Commit 5e6839d

Browse files
fzimmermann89facebook-github-bot
authored andcommitted
Tighten type hints for tensor arithmetic (#2550)
Summary: Pull Request resolved: #2550 X-link: pytorch/executorch#6752 Fixes pytorch/pytorch#124015 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov X-link: pytorch/pytorch#135392 Reviewed By: bobrenjc93, malfet Differential Revision: D65753120 Pulled By: ezyang fbshipit-source-id: 28ae18661b71078fa1dfdf8c60d1a75bec2f3e39
1 parent eccb5b3 commit 5e6839d

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrec/metrics/recall_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,13 @@ def get_recall_states(
205205
) -> Dict[str, torch.Tensor]:
206206

207207
predictions_ranked = ranking_within_session(predictions, session)
208+
# pyre-fixme[58]: `<` is not supported for operand types `Tensor` and
209+
# `Optional[int]`.
208210
predictions_labels = (predictions_ranked < self.top_threshold).to(torch.int32)
209211
if self.run_ranking_of_labels:
210212
labels_ranked = ranking_within_session(labels, session)
213+
# pyre-fixme[58]: `<` is not supported for operand types `Tensor` and
214+
# `Optional[int]`.
211215
labels = (labels_ranked < self.top_threshold).to(torch.int32)
212216
num_true_pos = _calc_num_true_pos(labels, predictions_labels, weights)
213217
num_false_neg = _calc_num_false_neg(labels, predictions_labels, weights)

0 commit comments

Comments
 (0)