Skip to content

Commit 7ac294f

Browse files
committed
AOTriton 0.7.1 compile fix
1 parent a1e8b0e commit 7ac294f

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ _efficient_attention_backward(
421421
using aotriton::v2::flash::attn_bwd;
422422
using sdp::aotriton_adapter::mk_aotensor;
423423
using sdp::aotriton_adapter::cast_dtype;
424+
using sdp::aotriton_adapter::mk_aoscalartensor;
424425
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
425426
err = attn_bwd(mk_aotensor(q_t, "q"),
426427
mk_aotensor(k_t, "k"),

aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
433433
using aotriton::v2::flash::attn_bwd;
434434
using sdp::aotriton_adapter::mk_aotensor;
435435
using sdp::aotriton_adapter::cast_dtype;
436+
using sdp::aotriton_adapter::mk_aoscalartensor;
436437
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
437438
err = attn_bwd(mk_aotensor(q_t, "q"),
438439
mk_aotensor(k_t, "k"),

0 commit comments

Comments
 (0)