Skip to content

Commit b5a89bb

Browse files
huydhn and lezcano authored
Fix broadcasting cosine_similarity (pytorch#114795)
* Fix broadcasting cosine_similarity (pytorch#109363) Fixes pytorch#109333 Pull Request resolved: pytorch#109363 Approved by: https://github.com/peterbell10 * The PR incidentally fixes the test by switching from sizes to sym_sizes test_make_fx_symbolic_exhaustive_masked_scatter_cpu_float32 --------- Co-authored-by: lezcano <[email protected]>
1 parent 3f662b6 commit b5a89bb

File tree

4 files changed

+25
-11
lines changed

4 files changed

+25
-11
lines changed

aten/src/ATen/ExpandUtils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,18 @@ expand_inplace(
187187
// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
188188
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
189189
expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
190-
if (to_expand1.sizes().equals(to_expand2.sizes())) {
190+
auto s1 = to_expand1.sym_sizes();
191+
auto s2 = to_expand2.sym_sizes();
192+
if (s1.equals(s2)) {
191193
return std::make_tuple(
192194
c10::MaybeOwned<Tensor>::borrowed(to_expand1),
193195
c10::MaybeOwned<Tensor>::borrowed(to_expand2));
194196
}
195197

196-
auto expanded_size =
197-
infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
198+
auto expanded_size = infer_size_symdimvector(s1, s2);
198199
return std::make_tuple(
199-
c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
200-
c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)));
200+
c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
201+
c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
201202
}
202203

203204
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>

aten/src/ATen/native/Distance.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,24 +308,26 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub
308308
// We accept integral types (and bools lol) but vector_norm does not
309309
auto x1_is_int = c10::isIntegralType(x1_.scalar_type(), /*includeBool=*/true);
310310
auto x2_is_int = c10::isIntegralType(x2_.scalar_type(), /*includeBool=*/true);
311-
auto x1 = x1_is_int ? x1_.to(commonDtype) : x1_;
312-
auto x2 = x2_is_int ? x2_.to(commonDtype) : x2_;
311+
auto x1_t = x1_is_int ? x1_.to(commonDtype) : x1_;
312+
auto x2_t = x2_is_int ? x2_.to(commonDtype) : x2_;
313+
c10::MaybeOwned<Tensor> x1, x2;
314+
std::tie(x1, x2) = expand_outplace(x1_t, x2_t);
313315

314316

315317
// We want to divide each tensor by its norm first, as it's more numerically stable.
316318
// This keeps the result between -1.0 and 1.0
317319
// We clone them, as we're going to modify them in-place
318320
// This allows the gradients to propagate properly all the way to x1 and x2
319-
auto x1_norm = at::linalg_vector_norm(x1, 2, /*dim=*/dim, /*keepdim=*/true).clone();
320-
auto x2_norm = at::linalg_vector_norm(x2, 2, /*dim=*/dim, /*keepdim=*/true).clone();
321+
auto x1_norm = at::linalg_vector_norm(*x1, 2, /*dim=*/dim, /*keepdim=*/true).clone();
322+
auto x2_norm = at::linalg_vector_norm(*x2, 2, /*dim=*/dim, /*keepdim=*/true).clone();
321323

322324
{
323325
at::NoGradGuard guard;
324326
x1_norm.clamp_min_(eps);
325327
x2_norm.clamp_min_(eps);
326328
}
327329

328-
return ((x1 / x1_norm) * (x2 / x2_norm)).sum(dim);
330+
return ((*x1 / x1_norm) * (*x2 / x2_norm)).sum(dim);
329331
}
330332

331333
}} // namespace at::native

test/test_nn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5609,6 +5609,18 @@ def test_cosine_similarity(self):
56095609
out = F.cosine_similarity(input.to(torch.int8), input, dim=-1)
56105610
self.assertEqual(out, 1.)
56115611

5612+
# Check broadcasting #109333
5613+
a = torch.ones(2, 3, dtype=torch.float)
5614+
b = torch.ones(1, 1, dtype=torch.float)
5615+
out = F.cosine_similarity(a, b)
5616+
self.assertEqual(out, torch.ones(2, dtype=torch.float))
5617+
5618+
a = torch.ones(2, 3, dtype=torch.float)
5619+
b = torch.ones(1, dtype=torch.float)
5620+
out = F.cosine_similarity(a, b)
5621+
self.assertEqual(out, torch.ones(2, dtype=torch.float))
5622+
5623+
56125624
def test_grid_sample_error_checking(self):
56135625
input = torch.empty(1, 1, 2, 2)
56145626
grid = torch.empty(1, 1, 1, 2)

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1604,7 +1604,6 @@ def f(t):
16041604

16051605
outplace_symbolic_tensor_failures = {
16061606
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
1607-
xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
16081607
}
16091608

16101609
inplace_symbolic_tensor_failures = {

0 commit comments

Comments
 (0)