
Commit 923ff39

JacobSzwejbka authored and facebook-github-bot committed
testing correctness between kv cache, non-kv cache, and eager
Differential Revision: D55465401
1 parent 72de6f3 commit 923ff39

File tree

6 files changed, +157501 -1 lines changed


examples/models/llama2/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ runtime.python_binary(
     deps = [
         ":export_library",
         "//caffe2:torch",
-        "//executorch/extension/pybindings:aten_lib",
+        "//executorch/extension/pybindings:portable_lib",
     ],
 )
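The dependency swap (aten_lib to portable_lib) matches the new import added to export_llama_lib.py below, which pulls _load_for_executorch_from_buffer from the portable pybindings. A minimal smoke test of that import, assuming the portable pybindings are built and on the Python path (this script is not part of the commit):

# Should resolve once the TARGETS entry points at portable_lib.
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

print(_load_for_executorch_from_buffer)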

examples/models/llama2/export_kv_d32.txt

Lines changed: 40761 additions & 0 deletions
Large diffs are not rendered by default.

examples/models/llama2/export_kv_q.txt

Lines changed: 41756 additions & 0 deletions
Large diffs are not rendered by default.

examples/models/llama2/export_llama_lib.py

Lines changed: 30 additions & 0 deletions
@@ -26,6 +26,9 @@
 
 from executorch.examples.models.llama2.llama_transformer import Transformer
 from executorch.exir.backend.backend_details import CompileSpec
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -774,6 +777,33 @@ def _export_llama(modelname, args) -> str: # noqa: C901
     else:
         output_file = f"{builder.output_dir}/{modelname}.pte"
 
+    builder.export_program.dump_executorch_program(True)
+    et_model = _load_for_executorch_from_buffer(builder.export_program.buffer)
+    eager_model = builder.model
+    input1 = torch.tensor([[0]], dtype=torch.long)
+    input2 = torch.tensor([0], dtype=torch.long)
+    if args.use_kv_cache:
+        for i in range(0, 100):
+            eager_res = eager_model(input1, input2)
+            et_res = et_model((input1, input2))
+            assert len(et_res) == 1
+            print("eager res kv", i, eager_res)
+            print("et res kv", i, et_res[0])
+            # assert torch.allclose(eager_res, et_res[0], atol=1e-05, rtol=1e-05)
+            input1 = torch.tensor([[i + 1]], dtype=torch.long)
+            input2 = torch.tensor([i + 1], dtype=torch.long)
+    else:
+        for i in range(0, 100):
+            eager_res = eager_model(input1)
+            et_res = et_model((input1,))
+            assert len(et_res) == 1
+            print("eager res", i, eager_res)
+            print("et res", i, et_res[0])
+            # assert torch.allclose(eager_res, et_res[0], atol=1e-05, rtol=1e-05)
+            input1 = torch.cat(
+                [input1, torch.tensor([[i + 1]], dtype=torch.long)], dim=1
+            )
+
     builder.save_to_pte(output_file)
 
     return output_file
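The added loops only print the eager and ExecuTorch results side by side; the torch.allclose assertion is left commented out. Below is a sketch of the same comparison expressed as a hard numerical check, based directly on the loop in the diff. The helper name check_eager_vs_et is made up here, and the tolerances are taken from the commented-out assertion.

import torch

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)


def check_eager_vs_et(builder, use_kv_cache: bool, steps: int = 100) -> None:
    # Load the exported program into the ExecuTorch runtime, as in the diff.
    et_model = _load_for_executorch_from_buffer(builder.export_program.buffer)
    eager_model = builder.model

    tokens = torch.tensor([[0]], dtype=torch.long)  # (1, seq_len) token ids
    pos = torch.tensor([0], dtype=torch.long)  # current position (kv-cache mode only)

    for i in range(steps):
        if use_kv_cache:
            eager_res = eager_model(tokens, pos)
            et_res = et_model((tokens, pos))
        else:
            eager_res = eager_model(tokens)
            et_res = et_model((tokens,))
        assert len(et_res) == 1
        # Same tolerances as the commented-out assertion in the diff.
        assert torch.allclose(
            eager_res, et_res[0], atol=1e-05, rtol=1e-05
        ), f"eager/ExecuTorch mismatch at step {i}"
        if use_kv_cache:
            # kv-cache mode feeds one new token and its position each step.
            tokens = torch.tensor([[i + 1]], dtype=torch.long)
            pos = torch.tensor([i + 1], dtype=torch.long)
        else:
            # non-kv-cache mode re-feeds the whole growing sequence.
            tokens = torch.cat(
                [tokens, torch.tensor([[i + 1]], dtype=torch.long)], dim=1
            )

Called in place of the print loops, e.g. check_eager_vs_et(builder, args.use_kv_cache), this would fail the export step as soon as the runtime output diverges from eager beyond the stated tolerances.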
