Add tests for cross attention #7609

Merged 1 commit on Jan 11, 2025
153 changes: 129 additions & 24 deletions extension/llm/modules/test/test_attention.py
@@ -33,6 +33,7 @@ def setUp(self):
self.num_kv_heads = 8
self.head_dim = 64
self.max_seq_len = 128
self.encoder_max_seq_len = 128
self.rope_base = 500_000
self.scale_factor = 32

@@ -86,16 +87,26 @@ def setUp(self):
max_seq_len=self.max_seq_len,
)
self.et_mha.load_state_dict(self.tt_mha.state_dict())

# Common inputs.
seq_len = 10
self.x = torch.randn(1, seq_len, self.embed_dim)
self.y = torch.randn(1, seq_len, self.embed_dim)
self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len]
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        self.dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
-        )
+        self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
+        self.dynamic_shapes = {
+            "x": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "y": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
+        }
self.causal_mask = torch.tril(
torch.ones(
size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
assert_close(et_res, tt_res)

# test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

et_res = self.et_mha(self.x, self.x) # Self attention.
tt_res = self.tt_mha(self.x, self.x) # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
# Self attention.

# test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
# Self attention.

# test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
with torch.no_grad():
so = torch._export.aot_compile(
self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):

def test_attention_executorch(self):
# Self attention.
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):

def test_attention_torch_cond_eager(self):
# Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
-        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

mask = self.causal_mask[self.input_pos, :]
# First run.
-        et_res = self.et_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        ) # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        ) # Self attention with input pos.
+        et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)

empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)

def test_attention_torch_cond_export(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)

# First run.
et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = et_mha_ep.module()(
self.x, empty_y, mask=mask, input_pos=next_input_pos
)
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res, tt_res)

def test_attention_torch_cond_executorch(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
mask = self.causal_mask[self.input_pos, :]
dynamic_shapes = {
**self.dynamic_shapes,
**{
"mask": {
0: torch.export.Dim.STATIC,
1: self.seq_len_dim,
2: torch.export.Dim.STATIC,
}
},
}
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.y),
kwargs={
"mask": mask,
"input_pos": self.input_pos,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
et_program = to_edge(
et_mha_ep,
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
_check_ir_validity=False,
),
).to_executorch(
config=ExecutorchBackendConfig(
passes=[InitializedMutableBufferPass(["cache_pos"])],
)
)

# First run.
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
et_res = method.execute((self.x, self.y, mask, self.input_pos))
tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

assert_close(et_res[0], tt_res)

# Second run tests kv cache read. Input pos is [10, 11, ..., 19]
next_input_pos = torch.arange(10, 20).unsqueeze(0)
empty_y = torch.full_like(self.y, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = method.execute((self.x, empty_y, mask, next_input_pos))
tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

assert_close(et_res[0], tt_res)
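
The torch.cond tests above lean on a convention worth spelling out: as the comment in test_attention_torch_cond_eager notes, the ExecuTorch attention module replaces torchtune's data-dependent `if` on the cross-attention input with torch.cond, and the tests signal "no new encoder input" by passing a NaN-filled tensor to the ExecuTorch module where the torchtune reference gets `y=None`. The toy module below is a minimal sketch of that idiom only; the class name, dimensions, cache buffer, and the isnan-based predicate are illustrative assumptions, not the real MultiHeadAttention implementation.

```python
import torch


class ToyCondCache(torch.nn.Module):
    """Minimal sketch of the torch.cond idiom exercised by the tests above.

    If `y` is a NaN placeholder, return a cached projection; otherwise
    recompute it from `y`. Illustrative only, not the ExecuTorch module.
    """

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        # Stand-in for the kv-cache buffer of the real attention module.
        self.register_buffer("cached", torch.zeros(1, 4, dim))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        def use_cache(y_arg, cached):
            # Taken when the caller passes an all-NaN `y` (the "y is None" case).
            return cached.clone()

        def recompute(y_arg, cached):
            # Taken when a real encoder output `y` is provided.
            return self.proj(y_arg)

        # Data-dependent predicate replacing an eager `if y is None:` check.
        # Both branches take the same operands and return tensors with matching
        # shape and dtype, which is what keeps the control flow exportable.
        return torch.cond(torch.isnan(y).all(), use_cache, recompute, (y, self.cached))


m = ToyCondCache()
y = torch.randn(1, 4, 8)
with torch.no_grad():  # mirror the tests, which run their checks without grad
    out_new = m(y)                                    # recompute branch
    out_cached = m(torch.full_like(y, float("nan")))  # cache branch
```

In the tests this convention shows up as a first run with a real `y`, followed by a second run with `empty_y = torch.full_like(self.y, torch.nan)` that must match the torchtune reference called with `y=None`.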