
Commit 5965649

Require less alignment for attn bias (pytorch#114173) (pytorch#114837)
Improved fix for the attention-mask alignment issue (pytorch#112577)

This PR addresses issue pytorch#112577 by refining the previously implemented fix, which was found to be incorrect and caused unneeded memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem-efficient attention.

Alignment check and padding: the alignment of the attention mask is checked first. If misalignment is detected, padding is applied, followed by slicing, and a warning is raised to alert users. (Should this be warn_once?) We now call expand only once, on the aligned mask.

Reference: https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115

@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki

Pull Request resolved: pytorch#114173
Approved by: https://github.com/danthe3rd
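As a rough sketch of the pad-then-slice idea in plain PyTorch (illustrative only, not the ATen implementation; the helper name and the alignment constant below are assumptions for the example):

import torch
import torch.nn.functional as F

ALIGNMENT = 8  # illustrative value for this example

def pad_and_slice_bias(attn_bias: torch.Tensor) -> torch.Tensor:
    # Pad the last dimension up to the next multiple of ALIGNMENT, then slice
    # back to the original length: the logical shape is unchanged, but the
    # underlying storage (and hence the leading strides) is now aligned.
    last_dim = attn_bias.size(-1)
    pad_amount = (-last_dim) % ALIGNMENT
    if pad_amount == 0:
        return attn_bias
    padded = F.pad(attn_bias, [0, pad_amount])
    return padded[..., :last_dim]

bias = torch.rand(1, 1, 16, 15)          # last dim 15 is not 8-aligned
aligned = pad_and_slice_bias(bias)
print(aligned.shape, aligned.stride())   # same shape, stride(2) is now 16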
1 parent 41210ea commit 5965649

File tree: 8 files changed (+192, -43 lines)


aten/src/ATen/native/transformers/attention.cpp

Lines changed: 14 additions & 24 deletions

@@ -590,9 +590,14 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
 // We apply this function to the top level SDPA so that
 // if padding is done it will be tracked for backward automatically

-template <int alignment>
-bool is_aligned(const SymInt& size){
-  return size % alignment == 0;
+template<int alignment>
+bool aligned_tensor(const at::Tensor& tensor){
+  for(const auto i : c10::irange(tensor.dim() - 1)){
+    if(tensor.sym_stride(i) % alignment != 0){
+      return false;
+    }
+  }
+  return tensor.sym_stride(-1) == 1;
 }

 template <int alignment>
@@ -608,31 +613,16 @@ at::Tensor preprocess_mask(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 16;
-  // Expand to 4d case
-  at::Tensor attn_mask = mask.expand_symint(
+  constexpr int mem_eff_alignment = 8;
+  at::Tensor result_mask = mask;
+  if (!aligned_tensor<mem_eff_alignment>(mask)) {
+    result_mask = pad_bias<mem_eff_alignment>(mask);
+  }
+  return result_mask.expand_symint(
       {query.sym_size(0),
        query.sym_size(1),
        query.sym_size(2),
        key.sym_size(2)});
-
-  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
-  // Apply pad_bias and store the result in attn_mask
-  if (!aligned_last_dim) {
-    return pad_bias<mem_eff_alignment>(attn_mask);
-  }
-  // Check and make the tensor contiguous if needed
-  auto needs_contig = [](const c10::SymInt& stride) {
-    return (stride % 16 != 0) || (stride == 0);
-  };
-  if (needs_contig(attn_mask.sym_stride(0)) ||
-      needs_contig(attn_mask.sym_stride(1)) ||
-      needs_contig(attn_mask.sym_stride(2)) ||
-      needs_contig(attn_mask.sym_stride(3))) {
-    return attn_mask.contiguous();
-  }
-
-  return attn_mask;
 }

 } // namespace
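For intuition, a rough Python equivalent of the new stride-based check (a sketch only; the real aligned_tensor operates on SymInts inside ATen, and the helper below is illustrative):

import torch
import torch.nn.functional as F

def aligned_tensor(t: torch.Tensor, alignment: int = 8) -> bool:
    # Every stride except the last must be a multiple of `alignment`,
    # and the last dimension must be dense (stride 1).
    if any(t.stride(i) % alignment != 0 for i in range(t.dim() - 1)):
        return False
    return t.stride(-1) == 1

bias = torch.rand(1, 1, 16, 15)
print(aligned_tensor(bias))            # False: stride(2) == 15
padded = F.pad(bias, [0, 1])[..., :15]
print(aligned_tensor(padded))          # True: stride(2) == 16, stride(-1) == 1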

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h

Lines changed: 11 additions & 5 deletions

@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned (strideH)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
-        "query is not correctly aligned (strideB)");
+        "query is not correctly aligned (strideB).");
     TORCH_CHECK(
         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
         "key is not correctly aligned (strideB)");
@@ -1216,13 +1216,19 @@
     if (p.bias_ptr) {
       TORCH_CHECK(
           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideB)");
+          "attn_bias is not correctly aligned (strideB). ",
+          "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
+          "multiple of ", kMinimumAlignment, ".");
       TORCH_CHECK(
           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideH)");
+          "attn_bias is not correctly aligned (strideH). "
+          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
+          "multiple of ", kMinimumAlignment, ".");
       TORCH_CHECK(
-          p.bias_strideM % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideM)");
+          p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
+          "attn_bias is not correctly aligned (strideM). "
+          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
+          "multiple of ", kMinimumAlignment, ".");
     }
     if (p.grad_bias_ptr) {
       TORCH_CHECK(

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h

Lines changed: 10 additions & 4 deletions

@@ -578,13 +578,19 @@ struct AttentionKernel {
     CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
     TORCH_CHECK(
         p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned (strideB)");
+        "attn_bias is not correctly aligned (strideB). ",
+        "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
+        "multiple of ", kAlignmentQ, ".");
     TORCH_CHECK(
         p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned (strideH)");
+        "attn_bias is not correctly aligned (strideH). "
+        "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
+        "multiple of ", kAlignmentQ, ".");
     TORCH_CHECK(
-        p.bias_strideM % kAlignmentQ == 0,
-        "attn_bias is not correctly aligned");
+        p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
+        "attn_bias is not correctly aligned (strideM). "
+        "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
+        "multiple of ", kAlignmentQ, ".");
   }
   TORCH_CHECK(
       p.q_strideM % kAlignmentQ == 0,
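Both the forward and backward kernels now follow the same pattern: a bias stride only has to be aligned when the corresponding dimension is actually iterated over (batch, head, query). A loose Python rendering of the three checks above (the function and argument names are illustrative, not kernel code):

def check_bias_alignment(bias_strides, num_batches, num_heads, num_queries, alignment):
    # Mirrors the three TORCH_CHECKs above: each leading stride must be a
    # multiple of `alignment`, but only if the corresponding dimension is
    # actually larger than 1.
    stride_b, stride_h, stride_m, _ = bias_strides
    assert num_batches <= 1 or stride_b % alignment == 0, "attn_bias is not correctly aligned (strideB)"
    assert num_heads <= 1 or stride_h % alignment == 0, "attn_bias is not correctly aligned (strideH)"
    assert num_queries <= 1 or stride_m % alignment == 0, "attn_bias is not correctly aligned (strideM)"

# A (1, 1, 16, 15) bias expanded to (8, 8, 16, 15) has strides (0, 0, 15, 1):
# strideB and strideH pass (stride 0 is trivially aligned), but strideM == 15
# would fail until the bias is padded and sliced, giving strides (0, 0, 16, 1).
check_bias_alignment((0, 0, 16, 1), num_batches=8, num_heads=8, num_queries=16, alignment=8)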

test/inductor/test_torchinductor.py

Lines changed: 46 additions & 0 deletions

@@ -6441,6 +6441,52 @@ def fn_or(x, y):
             (torch.randn(32), torch.randn(32)),
         )

+    @requires_cuda()
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FUSED_SDPA,
+        "Does not support mem_eff_attention",
+    )
+    @skipIfRocm
+    def test_sdpa_unaligned_mask(self):
+        def foo(
+            arg0_1: "f32[8, 8, 16, 16]",
+            arg1_1: "f32[8, 8, 15, 16]",
+            arg2_1: "f32[8, 8, 15, 16]",
+            arg3_1: "f32[1, 1, 16, 15]",
+        ):
+            constant_pad_nd: "f32[1, 1, 16, 16]" = (
+                torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
+            )
+            arg3_1 = None
+            slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
+                constant_pad_nd, -1, 0, 15
+            )
+            constant_pad_nd = None
+            expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
+                slice_1, [8, 8, 16, 15]
+            )
+            slice_1 = None
+            _scaled_dot_product_efficient_attention = (
+                torch.ops.aten._scaled_dot_product_efficient_attention.default(
+                    arg0_1, arg1_1, arg2_1, expand, False
+                )
+            )
+            arg0_1 = arg1_1 = arg2_1 = expand = None
+            getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
+            _scaled_dot_product_efficient_attention = None
+            return (getitem,)
+
+        query = torch.rand(8, 8, 16, 16, device="cuda")
+        key = torch.rand(8, 8, 15, 16, device="cuda")
+        value = torch.rand(8, 8, 15, 16, device="cuda")
+        bias = torch.rand(1, 1, 16, 15, device="cuda")
+        self.common(
+            foo,
+            (query, key, value, bias),
+            atol=0.02,
+            rtol=1e4,
+        )
+
     @skipIfRocm
     def test_conv_with_as_strided(self):
         class Model(nn.Module):
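For reference, a hypothetical user-level version of the scenario this test exercises (assumes a CUDA build where the memory-efficient SDPA backend is available; this snippet is not part of the test suite):

import torch
import torch.nn.functional as F

def sdpa(q, k, v, mask):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Bias with a last dimension of 15, which is not a multiple of the required
# alignment, broadcast over batch and heads.
query = torch.rand(8, 8, 16, 16, device="cuda")
key = torch.rand(8, 8, 15, 16, device="cuda")
value = torch.rand(8, 8, 15, 16, device="cuda")
bias = torch.rand(1, 1, 16, 15, device="cuda")

compiled = torch.compile(sdpa)
out = compiled(query, key, value, bias)
print(out.shape)  # torch.Size([8, 8, 16, 16])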

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions

@@ -259,6 +259,7 @@ def run(*ex, **kwargs):
     "test_zero_dim_reductions_dynamic_shapes": TestFailure(
         ("cpu", "cuda"), is_skip=True
     ),
+    "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
     #
     # The following tests do not support dynamic shapes yet:
     #

test/test_transformers.py

Lines changed: 18 additions & 1 deletion

@@ -1760,7 +1760,6 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
         out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()

-
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
@@ -1801,6 +1800,24 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_

         self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
+    def test_mem_eff_attention_non_contig_mask_bug(self, device):
+        dtype = torch.float32
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim = 1, 16, 128
+        seq_len_q, seq_len_kv = 1, 16
+        query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
+        kv_shape = (batch, seq_len_kv, head_dim)
+        key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
+        key = key.expand(-1, num_heads, -1, -1)
+        value = value.expand(-1, num_heads, -1, -1)
+        mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
+        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+            out = F.scaled_dot_product_attention(query, key, value, mask)
+            out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
+        max_diff = (out - out_no_mask).abs().mean()
+        assert max_diff.item() < 1e-9
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])

torch/_inductor/lowering.py

Lines changed: 85 additions & 4 deletions

@@ -1866,10 +1866,91 @@ def apply_constraint(arg, fx_arg):
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 make_fallback(aten.grid_sampler_2d_backward, require_dense)
 make_fallback(aten.randperm)
-make_fallback(aten._scaled_dot_product_efficient_attention)
-make_fallback(aten._scaled_dot_product_efficient_attention_backward)
-make_fallback(aten._scaled_dot_product_flash_attention)
-make_fallback(aten._scaled_dot_product_flash_attention_backward)
+
+
+def sdpa_constraint(fx_node, *args, **kwargs):
+    # sdpa requires dense last dimension
+    def apply_constraint(arg, fx_arg):
+        if not isinstance(arg, ir.IRNode):
+            return arg
+
+        meta_val = fx_arg.meta["val"]
+        if not meta_val.is_cuda:
+            return arg
+
+        stride_order = ir.get_stride_order(meta_val.stride())
+        if stride_order and stride_order[-1] != 0:
+            # contiguous stride order
+            stride_order = list(reversed(range(len(arg.get_size()))))
+
+        # This is the minimum alignment required by SDPA kernels for attention_bias.
+        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
+        ALIGNMENT = 8
+
+        is_backward = fx_node.target in (
+            aten._scaled_dot_product_efficient_attention_backward.default,
+            aten._scaled_dot_product_flash_attention_backward.default,
+        )
+
+        def is_aligned(x):
+            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
+
+        assert isinstance(arg, TensorBox)
+
+        # This correctly handles the forward case:
+        if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
+            if not is_aligned(arg):
+                # input is padded, requiring_stride_order will unwrap the view and unpad.
+                # Would be nice to be able to require certain padding from inductor ir, nyi
+                if is_aligned(arg.unwrap_view()):
+                    return arg
+
+        def is_aligned_backward(x):
+            aligned_strides = all(
+                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
+                for i in range(len(x.get_stride()) - 1)
+            )
+            return (
+                V.graph.sizevars.size_hint(x.get_stride()[-1])
+            ) == 1 and aligned_strides
+
+        if (
+            isinstance(arg.data, ir.StorageBox)
+            and arg.data.is_input_buffer()
+            and is_backward
+        ):
+            if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
+                return arg
+
+        return ir.ExternKernel.require_stride_order(arg, stride_order)
+
+    args = tuple(
+        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
+    )
+    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
+make_fallback(
+    aten._scaled_dot_product_efficient_attention,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_efficient_attention_backward,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_flash_attention,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_flash_attention_backward,
+    sdpa_constraint,
+    warn=False,
+)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)

torch/_meta_registrations.py

Lines changed: 7 additions & 5 deletions

@@ -4985,12 +4985,14 @@ def meta__scaled_dot_product_efficient_backward(
     )
     grad_bias = None
     if attn_bias is not None and grad_input_mask[3]:
-        grad_bias = torch.empty_strided(
-            attn_bias.size(),
-            attn_bias.stride(),
-            dtype=attn_bias.dtype,
-            device=attn_bias.device,
+        lastDim = attn_bias.size(-1)
+        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
+        new_sizes = list(attn_bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(
+            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
         )
+        grad_bias = grad_bias[..., :lastDim]

     return grad_q, grad_k, grad_v, grad_bias
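As a standalone illustration of why the meta function allocates a padded buffer and then slices it (illustrative shapes, not the meta registration itself): the sliced view keeps the original logical size while its row stride stays a multiple of 16, matching the aligned layout the kernel produces.

import torch

attn_bias = torch.rand(8, 8, 16, 15)  # last dim 15 is not a multiple of 16

lastDim = attn_bias.size(-1)
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
new_sizes = list(attn_bias.size())
new_sizes[-1] = lastDimAligned

# Allocate with the padded last dimension, then slice back to the logical size.
grad_bias = torch.empty(new_sizes, dtype=attn_bias.dtype, device=attn_bias.device)
grad_bias = grad_bias[..., :lastDim]

print(grad_bias.shape)    # torch.Size([8, 8, 16, 15])
print(grad_bias.stride()) # (2048, 256, 16, 1)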