Skip to content

Commit 49858f6

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Allow integral Scalar to be extracted as bool
Summary: In ATen, integral `Scalar`s are allowed to be extracted as `bool`. Example: ``` torch.ops.aten.add.out(torch.Tensor([0, 1]).to(dtype=torch.bool), torch.Tensor([1, 1]).to(dtype=torch.bool), alpha=4, out=torch.zeros(1).to(dtype=torch.bool)) ``` This updates `extract_scalar` for the `bool` case to match this behaviour, where previously it would fail unless the `Scalar` was boolean. Reviewed By: manuelcandales Differential Revision: D46990102 fbshipit-source-id: 6d2875cfafbe2723cbf6855af55ded54c154e970
1 parent 40df311 commit 49858f6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,15 @@ template <
176176
typename BOOL_T,
177177
std::enable_if_t<std::is_same_v<BOOL_T, bool>, bool> = true>
178178
bool extract_scalar(Scalar scalar, BOOL_T* out_val) {
179-
if (!scalar.isBoolean()) {
180-
return false;
179+
if (scalar.isIntegral(false)) {
180+
*out_val = static_cast<bool>(scalar.to<int64_t>());
181+
return true;
181182
}
182-
*out_val = scalar.to<bool>();
183-
return true;
183+
if (scalar.isBoolean()) {
184+
*out_val = scalar.to<bool>();
185+
return true;
186+
}
187+
return false;
184188
}
185189

186190
} // namespace utils

0 commit comments

Comments
 (0)