
Commit bc32ea4

Fix memory overlap issue in copy_ of linear_kernel_output (#2627)
* Fix memory overlap issue in copy_ of linear_kernel
* Replace is_same with data_ptr check
1 parent: fac2423

File tree: 2 files changed, +18 −1 lines

csrc/cpu/aten/Linear.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -87,7 +87,7 @@ void linear_kernel_output(
   if (self.dim() != 2) {
     output_ = output_.reshape(output_size);
   }
-  if (!out_is_contiguous || !output.is_same(output_)) {
+  if (!out_is_contiguous || output.data_ptr() != output_.data_ptr()) {
     output.copy_(output_);
   }
 }
```
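Why the pointer comparison matters: when a reshape can be satisfied by a view, it returns a new Tensor handle over the same memory, so an identity test like ATen's Tensor::is_same reports the two as different even though they alias. A minimal Python sketch of that distinction (using `is` as a rough stand-in for the C++ is_same check; the tensor names here are illustrative, not taken from the kernel):

```python
import torch

out = torch.rand(3, 1, 3)
# This reshape round-trip can be satisfied by views: out_ is a new
# Tensor object backed by the very same memory as out.
out_ = out.reshape(3, 3).reshape(3, 1, 3)

print(out_ is out)                        # False: distinct handles
print(out.data_ptr() == out_.data_ptr())  # True: same underlying storage
```

With the old check, such an aliasing pair still triggered output.copy_(output_), a copy between overlapping buffers; comparing data_ptr() instead skips the copy when the two tensors already share storage.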

tests/cpu/test_jit.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -761,6 +761,18 @@ def forward(self, x):
         return torch.add(self.linear(x), self.linear1(x1))
 
 
+class LinearAdd2(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearAdd2, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        y = x.clone().unsqueeze(0).permute(2, 1, 0, 3).squeeze(0)
+        return self.linear(x) + y
+
+
 class LinearAddRelu(nn.Module):
     def __init__(self, in_channels, mid_channels, out_channels, inplace, **kwargs):
         super(LinearAddRelu, self).__init__()
@@ -4480,6 +4492,11 @@ def test_output_linear_add(self):
             torch.rand(32, 3),
             kind_in_graph="ipex_prepack::linear_add_run",
         )
+        self._test_dnnl_fp32(
+            LinearAdd2(3, 3, bias=False),
+            torch.rand(3, 1, 3),
+            kind_in_graph="ipex_prepack::linear_add_run",
+        )
         self._test_output_lowp(
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
```
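A note on what the new regression test exercises (my reading of the diff, not stated in it): the unsqueeze/permute/squeeze round-trip in LinearAdd2.forward leaves the add operand with the same shape but non-default strides, which steers the fused ipex_prepack::linear_add_run kernel into the reshape-and-copy_ path patched above. The view chain in isolation:

```python
import torch

x = torch.rand(3, 1, 3)
# Same chain as LinearAdd2.forward: the shape survives the round-trip,
# but the size-1 middle dimension keeps a stride of 9 rather than 3.
y = x.clone().unsqueeze(0).permute(2, 1, 0, 3).squeeze(0)

print(y.shape)     # torch.Size([3, 1, 3])
print(y.stride())  # (3, 9, 1); a fresh tensor of this shape has (3, 3, 1)
```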
