
Commit a8e7c98

Revert "Require less alignment for attn bias (pytorch#114173) (pytorch#114837)"
This reverts commit 5965649.
1 parent 448700d commit a8e7c98

File tree: 8 files changed, +43 / -192 lines

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 24 additions & 14 deletions
@@ -590,14 +590,9 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
 // We apply this function to the top level SDPA so that
 // if padding is done it will be tracked for backward automatically

-template<int alignment>
-bool aligned_tensor(const at::Tensor& tensor){
-  for(const auto i : c10::irange(tensor.dim() - 1)){
-    if(tensor.sym_stride(i) % alignment != 0){
-      return false;
-    }
-  }
-  return tensor.sym_stride(-1) == 1;
+template <int alignment>
+bool is_aligned(const SymInt& size){
+  return size % alignment == 0;
 }

 template <int alignment>
@@ -613,16 +608,31 @@ at::Tensor preprocess_mask(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 8;
-  at::Tensor result_mask = mask;
-  if (!aligned_tensor<mem_eff_alignment>(mask)) {
-    result_mask = pad_bias<mem_eff_alignment>(mask);
-  }
-  return result_mask.expand_symint(
+  constexpr int mem_eff_alignment = 16;
+  // Expand to 4d case
+  at::Tensor attn_mask = mask.expand_symint(
       {query.sym_size(0),
        query.sym_size(1),
        query.sym_size(2),
        key.sym_size(2)});
+
+  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
+  // Apply pad_bias and store the result in attn_mask
+  if (!aligned_last_dim) {
+    return pad_bias<mem_eff_alignment>(attn_mask);
+  }
+  // Check and make the tensor contiguous if needed
+  auto needs_contig = [](const c10::SymInt& stride) {
+    return (stride % 16 != 0) || (stride == 0);
+  };
+  if (needs_contig(attn_mask.sym_stride(0)) ||
+      needs_contig(attn_mask.sym_stride(1)) ||
+      needs_contig(attn_mask.sym_stride(2)) ||
+      needs_contig(attn_mask.sym_stride(3))) {
+    return attn_mask.contiguous();
+  }
+
+  return attn_mask;
 }

 } // namespace
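For readers skimming the revert: the restored preprocess_mask expands the bias to the full 4-D attention shape, pads the last dimension up to a multiple of 16 when it is not aligned, and otherwise falls back to a contiguous copy if any stride is zero or not a 16-multiple. A rough Python sketch of that decision logic (illustrative only, not part of the commit; preprocess_mask_sketch is a hypothetical name):

    import torch

    ALIGNMENT = 16  # mem_eff_alignment restored in the hunk above

    def preprocess_mask_sketch(mask, query, key):
        # Expand the bias to the full 4-D attention shape (B, H, seq_q, seq_kv).
        attn_mask = mask.expand(
            query.size(0), query.size(1), query.size(2), key.size(2)
        )
        last = attn_mask.size(-1)
        if last % ALIGNMENT != 0:
            # Pad the last dim up to a multiple of 16, then slice back to the
            # logical size so only the storage is padded (mirrors pad_bias).
            pad = ALIGNMENT - last % ALIGNMENT
            return torch.nn.functional.pad(attn_mask, [0, pad])[..., :last]
        # Last dim already aligned: every stride must be a nonzero multiple of 16,
        # otherwise fall back to a contiguous copy.
        if any(s == 0 or s % ALIGNMENT != 0 for s in attn_mask.stride()):
            return attn_mask.contiguous()
        return attn_mask

With a (1, 1, 16, 15) bias against a 16-query / 15-key case, the last dimension (15) is not a 16-multiple, so the pad-and-slice branch is taken; that is exactly the pattern exercised by the inductor test removed further down in this commit.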

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h

Lines changed: 5 additions & 11 deletions
@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned (strideH)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
-        "query is not correctly aligned (strideB).");
+        "query is not correctly aligned (strideB)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
         "key is not correctly aligned (strideB)");
@@ -1216,19 +1216,13 @@
     if (p.bias_ptr) {
       TORCH_CHECK(
           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideB)");
       TORCH_CHECK(
           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideH) ."
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideH)");
       TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
-          "multiple of ", kMinimumAlignment, ".");
+          p.bias_strideM % kMinimumAlignment == 0,
+          "attn_bias is not correctly aligned (strideM)");
     }
     if (p.grad_bias_ptr) {
       TORCH_CHECK(

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h

Lines changed: 4 additions & 10 deletions
@@ -578,19 +578,13 @@ struct AttentionKernel {
       CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
       TORCH_CHECK(
           p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideB)");
       TORCH_CHECK(
           p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideH). "
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideH)");
       TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
     }
     TORCH_CHECK(
         p.q_strideM % kAlignmentQ == 0,

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 46 deletions
@@ -6441,52 +6441,6 @@ def fn_or(x, y):
             (torch.randn(32), torch.randn(32)),
         )

-    @requires_cuda()
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FUSED_SDPA,
-        "Does not support mem_eff_attention",
-    )
-    @skipIfRocm
-    def test_sdpa_unaligned_mask(self):
-        def foo(
-            arg0_1: "f32[8, 8, 16, 16]",
-            arg1_1: "f32[8, 8, 15, 16]",
-            arg2_1: "f32[8, 8, 15, 16]",
-            arg3_1: "f32[1, 1, 16, 15]",
-        ):
-            constant_pad_nd: "f32[1, 1, 16, 16]" = (
-                torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
-            )
-            arg3_1 = None
-            slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
-                constant_pad_nd, -1, 0, 15
-            )
-            constant_pad_nd = None
-            expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
-                slice_1, [8, 8, 16, 15]
-            )
-            slice_1 = None
-            _scaled_dot_product_efficient_attention = (
-                torch.ops.aten._scaled_dot_product_efficient_attention.default(
-                    arg0_1, arg1_1, arg2_1, expand, False
-                )
-            )
-            arg0_1 = arg1_1 = arg2_1 = expand = None
-            getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
-            _scaled_dot_product_efficient_attention = None
-            return (getitem,)
-
-        query = torch.rand(8, 8, 16, 16, device="cuda")
-        key = torch.rand(8, 8, 15, 16, device="cuda")
-        value = torch.rand(8, 8, 15, 16, device="cuda")
-        bias = torch.rand(1, 1, 16, 15, device="cuda")
-        self.common(
-            foo,
-            (query, key, value, bias),
-            atol=0.02,
-            rtol=1e4,
-        )
-
     @skipIfRocm
     def test_conv_with_as_strided(self):
         class Model(nn.Module):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion
@@ -259,7 +259,6 @@ def run(*ex, **kwargs):
     "test_zero_dim_reductions_dynamic_shapes": TestFailure(
         ("cpu", "cuda"), is_skip=True
     ),
-    "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
     #
     # The following tests do not support dynamic shapes yet:
     #

test/test_transformers.py

Lines changed: 1 addition & 18 deletions
@@ -1760,6 +1760,7 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
         out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
@@ -1800,24 +1801,6 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_

         self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
-    def test_mem_eff_attention_non_contig_mask_bug(self, device):
-        dtype = torch.float32
-        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
-        batch, num_heads, head_dim = 1, 16, 128
-        seq_len_q, seq_len_kv = 1, 16
-        query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
-        kv_shape = (batch, seq_len_kv, head_dim)
-        key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
-        key = key.expand(-1, num_heads, -1, -1)
-        value = value.expand(-1, num_heads, -1, -1)
-        mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-            out = F.scaled_dot_product_attention(query, key, value, mask)
-            out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
-        max_diff = (out - out_no_mask).abs().mean()
-        assert max_diff.item() < 1e-9
-
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])

torch/_inductor/lowering.py

Lines changed: 4 additions & 85 deletions
@@ -1866,91 +1866,10 @@ def apply_constraint(arg, fx_arg):
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 make_fallback(aten.grid_sampler_2d_backward, require_dense)
 make_fallback(aten.randperm)
-
-
-def sdpa_constraint(fx_node, *args, **kwargs):
-    # sdpa requires dense last dimension
-    def apply_constraint(arg, fx_arg):
-        if not isinstance(arg, ir.IRNode):
-            return arg
-
-        meta_val = fx_arg.meta["val"]
-        if not meta_val.is_cuda:
-            return arg
-
-        stride_order = ir.get_stride_order(meta_val.stride())
-        if stride_order and stride_order[-1] != 0:
-            # contiguous stride order
-            stride_order = list(reversed(range(len(arg.get_size()))))
-
-        # This is the minimum alignment required by SDPA kernels for attention_bias.
-        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
-        ALIGNMENT = 8
-
-        is_backward = fx_node.target in (
-            aten._scaled_dot_product_efficient_attention_backward.default,
-            aten._scaled_dot_product_flash_attention_backward.default,
-        )
-
-        def is_aligned(x):
-            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
-
-        assert isinstance(arg, TensorBox)
-
-        # This correctly handles the forward case:
-        if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
-            if not is_aligned(arg):
-                # input is padded, requiring_stride_order will unwrap the view and unpad.
-                # Would be nice to be able to require certain padding from inductor ir, nyi
-                if is_aligned(arg.unwrap_view()):
-                    return arg
-
-        def is_aligned_backward(x):
-            aligned_strides = all(
-                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
-                for i in range(len(x.get_stride()) - 1)
-            )
-            return (
-                V.graph.sizevars.size_hint(x.get_stride()[-1])
-            ) == 1 and aligned_strides
-
-        if (
-            isinstance(arg.data, ir.StorageBox)
-            and arg.data.is_input_buffer()
-            and is_backward
-        ):
-            if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
-                return arg
-
-        return ir.ExternKernel.require_stride_order(arg, stride_order)
-
-    args = tuple(
-        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
-    )
-    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
-    return args, kwargs
-
-
-make_fallback(
-    aten._scaled_dot_product_efficient_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_efficient_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
+make_fallback(aten._scaled_dot_product_efficient_attention)
+make_fallback(aten._scaled_dot_product_efficient_attention_backward)
+make_fallback(aten._scaled_dot_product_flash_attention)
+make_fallback(aten._scaled_dot_product_flash_attention_backward)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
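For reference, the sdpa_constraint removed above (besides its view-unwrapping and require_stride_order handling, which is not sketched here) enforced only 8-element alignment on the bias: on the forward path the last dimension's size had to be divisible by 8, and on the backward path every leading stride had to be an 8-multiple with a dense (stride-1) last dimension. After this revert the four SDPA ops are registered as plain fallbacks without that constraint. A standalone Python sketch of the two predicates, outside inductor (illustrative only; the helper names are hypothetical):

    import torch

    ALIGNMENT = 8  # the constant used by the removed sdpa_constraint

    def forward_bias_aligned(t: torch.Tensor) -> bool:
        # Forward-path test: the last dimension's size must be divisible by 8.
        return t.size(-1) % ALIGNMENT == 0

    def backward_bias_aligned(t: torch.Tensor) -> bool:
        # Backward-path test: dense last stride plus 8-aligned leading strides.
        return t.stride(-1) == 1 and all(
            s % ALIGNMENT == 0 for s in t.stride()[:-1]
        )

    print(forward_bias_aligned(torch.rand(1, 1, 16, 15)))   # False: 15 % 8 != 0
    print(backward_bias_aligned(torch.rand(8, 8, 16, 16)))  # True: strides (2048, 256, 16, 1)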

torch/_meta_registrations.py

Lines changed: 5 additions & 7 deletions
@@ -4985,14 +4985,12 @@ def meta__scaled_dot_product_efficient_backward(
     )
     grad_bias = None
     if attn_bias is not None and grad_input_mask[3]:
-        lastDim = attn_bias.size(-1)
-        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
-        new_sizes = list(attn_bias.size())
-        new_sizes[-1] = lastDimAligned
-        grad_bias = torch.empty(
-            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
+        grad_bias = torch.empty_strided(
+            attn_bias.size(),
+            attn_bias.stride(),
+            dtype=attn_bias.dtype,
+            device=attn_bias.device,
         )
-        grad_bias = grad_bias[..., :lastDim]

     return grad_q, grad_k, grad_v, grad_bias
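The meta-function change above is the mirror image of the ATen-side padding: the removed branch allocated grad_bias with its last dimension rounded up to the next multiple of 16 and then sliced back, while the restored branch simply matches attn_bias's own sizes and strides via torch.empty_strided. The round-up arithmetic, worked through on a few sizes (illustrative sketch, not part of the commit):

    def round_up_to_16(last_dim: int) -> int:
        # lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
        return last_dim if last_dim % 16 == 0 else last_dim + 16 - last_dim % 16

    assert round_up_to_16(15) == 16   # 15 % 16 == 15, so 15 + 16 - 15
    assert round_up_to_16(16) == 16   # already a multiple of 16
    assert round_up_to_16(17) == 32   # 17 % 16 == 1, so 17 + 16 - 1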
