
Commit 046a06c

xinyazhang authored and pruthvistony committed
CONSOLIDATED COMMITS: Bump to AOTriton 0.7.1b
=============================================

Bump to AOTriton 0.7.1b (#1572)
A cherry-picked version of pytorch#134498 for rocm6.3_internal_testing
(cherry picked from commit d28d7ff)

AOTriton 0.7.1 compile fix
(cherry picked from commit 7ac294f)
1 parent 6f76e3d commit 046a06c

File tree

4 files changed: +7 -0 lines changed


aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 1 addition & 0 deletions

@@ -445,6 +445,7 @@ _efficient_attention_backward(
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::cast_dtype;
+  using sdp::aotriton_adapter::mk_aoscalartensor;
   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
   if (cu_seqlens_q.has_value()) {
     // varlen aka Nested tensor
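For context, this hunk sits in the memory-efficient attention backward, which AOTriton backs on ROCm builds. A minimal sketch of exercising that path from Python is below; it uses the standard torch.nn.attention.sdpa_kernel / SDPBackend APIs and is not part of this commit.

# Minimal sketch, not part of this commit: drives the memory-efficient SDPA
# backward implemented in attention_backward.cu (via AOTriton on ROCm builds).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
out.sum().backward()  # runs the efficient-attention backward kernel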

aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip

Lines changed: 1 addition & 0 deletions

@@ -533,6 +533,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::cast_dtype;
+  using sdp::aotriton_adapter::mk_aoscalartensor;
   aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
   err = attn_bwd(mk_aotensor(q_t, "q"),
                  mk_aotensor(k_t, "k"),

test/inductor/test_flex_decoding.py

Lines changed: 1 addition & 0 deletions

@@ -1383,6 +1383,7 @@ def mask_mod(b, h, q, kv):
         loss.backward()
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

+    @skipIfRocm
     @supported_platform
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
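The @skipIfRocm decorator added here comes from PyTorch's shared test utilities and skips the test on ROCm builds. A self-contained sketch of the same pattern, assuming the usual torch.testing._internal.common_utils import (not part of this diff):

# Standalone sketch, not part of this commit: shows how skipIfRocm gates a test.
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm

class ExampleTests(TestCase):
    @skipIfRocm  # skipped when PyTorch is built for ROCm
    def test_cuda_only_behavior(self):
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()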

test/test_transformers.py

Lines changed: 4 additions & 0 deletions

@@ -3518,6 +3518,10 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
         torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
         # Create real output
         output_tuple = fused_op(query, key, value, **kwargs)
+        # for o in output_tuple:
+        #     print(f'{o.__class__=}')
+        #     if isinstance(o, torch.Tensor):
+        #         print(f'{o.is_cuda=}')
         assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
         g.replay()
         out_first = output_tuple[0].clone()
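The four added lines are commented-out debug output preceding the existing is_cuda assertion. Uncommented and pulled into a helper, the same check reads as follows (a sketch, not part of the test file):

# Sketch, not part of this commit: the debug loop the diff leaves commented out,
# followed by the assertion the test already performs on the fused-op outputs.
import torch

def debug_output_tuple(output_tuple):
    for o in output_tuple:
        print(f"{o.__class__=}")
        if isinstance(o, torch.Tensor):
            print(f"{o.is_cuda=}")
    assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)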
