Qualcomm AI Engine Direct - Support batch prefill mode for llama3.2 #6983

Merged · 8 commits · Dec 3, 2024
7 changes: 5 additions & 2 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -22,7 +22,9 @@
from torch.fx import Node


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
def annotate_matmul_16a8w( # noqa: C901
gm: torch.fx.GraphModule, traverse_input1=True
) -> None:
"""
This function is specific for matmul op 16a8w.
"""
@@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1])
if traverse_input1:
annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
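The added `traverse_input1` flag lets a caller annotate only the matmul node itself instead of also walking its second input. A minimal usage sketch, assuming the module is importable as `executorch.backends.qualcomm.quantizer.custom_annotation` and that the prefill path is the one that wants the traversal disabled:

```python
# Sketch only: wiring the new flag through functools.partial so the annotator can be
# handed around as a single-argument callable.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)

# Default behaviour: annotate the matmul and also walk its second input (KV-cache graphs).
kv_annotation = annotate_matmul_16a8w  # traverse_input1 defaults to True

# Prefill-style graphs (assumption for illustration): annotate only the matmul node itself.
prefill_annotation = partial(annotate_matmul_16a8w, traverse_input1=False)

# Both callables take a single GraphModule, so either fits the custom_annotations tuple
# accepted by SingleLlama.quantize(..., custom_annotations=(...,)) in this PR.
```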
131 changes: 120 additions & 11 deletions examples/qualcomm/oss_scripts/llama2/llama.py
@@ -177,14 +177,15 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
)


def calibrate(
def _kv_calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
tokenizer_model_path="tokenizer.model",
max_seq_len=512,
):
sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
_, _, atten_mask, k_caches, v_caches = example_inputs
_, atten_mask, _, k_caches, v_caches = example_inputs

# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int32)
@@ -202,11 +203,11 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
return probs_indices.gather(dim=-1, index=next_token)

with torch.no_grad():
while token_list[-1] != sp_model.eos_id() and pos < 128:
while token_list[-1] != sp_model.eos_id() and pos < max_seq_len - 1:
logits, new_k_caches, new_v_caches = module(
torch.full((1, 1), token_list[pos]),
torch.full((1, 1), pos),
atten_mask,
torch.full((1, 1), pos),
*k_caches,
*v_caches,
)
@@ -228,15 +229,84 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
print(f"calibration data:\n{sp_model.decode(token_list)}")


def _batch_prefill_calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
tokenizer_model_path="tokenizer.model",
max_seq_len=512,
):
sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
_, atten_mask = example_inputs
max_cache_len = max_seq_len - 1

# TODO: change criteria & support batch inputs if necessary
token_list = sp_model.encode(user_prompts, bos=True, eos=False)
token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
last_prompt_pos = token_list.numel()
if last_prompt_pos < max_cache_len:
token_list = torch.cat(
[
token_list,
torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
],
dim=1,
)
else:
token_list = token_list[:, :max_cache_len]

with torch.no_grad():
logits, new_k_caches, new_v_caches = module(
token_list,
atten_mask,
)
predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]

print(f"calibration data:\n{sp_model.decode(predict)}")


def calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
tokenizer_model_path="tokenizer.model",
max_seq_len=512,
):
if len(example_inputs) == 2:
_batch_prefill_calibrate(
example_inputs,
user_prompts,
module,
tokenizer_model_path,
max_seq_len,
)
elif len(example_inputs) == 5:
_kv_calibrate(
example_inputs,
user_prompts,
module,
tokenizer_model_path,
max_seq_len,
)
else:
raise RuntimeError("Get wrong inputs")


class SingleLlama:
def __init__(self, llama_model) -> None:
super().__init__()
self.llama_model = llama_model
self.quant_dtype = None
self.llama_meta = self.llama_model.get_metadata()
self.has_quant_io = False
tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)
if self.llama_meta["get_use_kv_cache"]:
tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
use_kv_cache=True
)
self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
else:
tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
self.inputs = (tokens, atten_mask)

def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
if not self.has_quant_io:
@@ -256,11 +326,17 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
n.meta[QCOM_QUANTIZED_IO] = kv_type
elif n.op == "output":
for a in n.args[0]:
# single head, kv mode
if (
a.meta["val"].flatten().size()[0]
== self.llama_meta["get_head_dim"]
):
a.meta[QCOM_QUANTIZED_IO] = kv_type
# single head, batch_prefill mode
elif a.meta["val"].flatten().size()[0] == self.llama_meta[
"get_head_dim"
] * (self.llama_meta["get_max_seq_len"] - 1):
a.meta[QCOM_QUANTIZED_IO] = kv_type
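The element-count checks above tell the two modes' per-head cache outputs apart: a KV-mode graph emits `head_dim` values per head per step, while a batch prefill graph emits the whole `head_dim * (max_seq_len - 1)` slice at once. A quick sketch with assumed sizes (`head_dim=64`, `max_seq_len=512` are illustrative; the script reads the real values from `llama_meta`):

```python
# Size-based output tagging illustrated with assumed numbers.
head_dim = 64
max_seq_len = 512

def classify_output(numel: int) -> str:
    # Mirrors the two branches in _tag_kv_ios: an exact element-count match decides
    # whether a graph output is a single-head KV slice and in which mode.
    if numel == head_dim:
        return "single head, kv mode -> tag as quantized IO"
    if numel == head_dim * (max_seq_len - 1):
        return "single head, batch_prefill mode -> tag as quantized IO"
    return "left untagged"

print(classify_output(64))        # kv mode: one new cache entry per step
print(classify_output(64 * 511))  # batch prefill: the full cache slice for the prompt
print(classify_output(100))       # anything else stays untagged
```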

def quantize(self, quant_dtype, custom_annotations=()):
self.quant_dtype = quant_dtype
@@ -281,11 +357,13 @@ def quantize(self, quant_dtype, custom_annotations=()):
).module()
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
print("Quantizing the model...")

calibrate(
self.get_example_inputs(),
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
args.prompt,
fx_graph_module,
tokenizer_model_path=args.tokenizer_model,
max_seq_len=args.seq_len,
)

self.llama_model = convert_pt2e(fx_graph_module)
@@ -328,18 +406,29 @@ def lowering_modules(
with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
exec_prog_mgr.write_to_file(file)

def get_example_inputs(self):
return self.llama_model.get_example_inputs()
def get_example_inputs(self, use_kv_cache=True):
return self.llama_model.get_example_inputs(use_kv_cache)


def compile(args):
os.makedirs(args.artifact, exist_ok=True)
start_ts = time.time()

if args.model_mode == "kv":
use_kv_cache = output_new_cache_only = True
elif args.model_mode == "batch_prefill" or args.model_mode == "hybrid":
raise NotImplementedError(
f"model_mode {args.model_mode} is not implemented yet."
)
else:
raise RuntimeError(f"No such model_mode {args.model_mode}.")

with open(args.params) as f:
config = ModelArgs(**json.load(f))
# TODO: support batch inputs if necessary
config.max_batch_size = 1
config.max_seq_len = 1024
config.max_seq_len = args.seq_len
config.use_kv_cache = use_kv_cache
state_dict = torch.load(
args.checkpoint, weights_only=True, map_location="cpu", mmap=True
)
@@ -348,7 +437,7 @@ def compile(args):

llama_instance = None
with torch.device("meta"):
llama_instance = LlamaModel(config, output_new_cache_only=True)
llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only)
if "model" in state_dict:
state_dict = state_dict["model"]
llama_instance.load_state_dict(
@@ -398,6 +487,11 @@ def compile(args):
def inference(args, pre_gen_pte=""):
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

if args.model_mode != "kv":
raise NotImplementedError(
f"model_mode {args.model_mode} is not implemented yet."
)

runner_args = " ".join(
[
f"--model_path {pte_filename}.pte",
@@ -550,6 +644,21 @@ def post_process():
type=str,
)

parser.add_argument(
"--num_sharding",
type=int,
default=0,
help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
)

parser.add_argument(
"--model_mode",
help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
default="kv",
choices=["batch_prefill", "kv", "hybrid"],
type=str,
)

args = parser.parse_args()
if args.compile_only and args.pre_gen_pte:
exit("Cannot set both compile_only and pre_gen_pte as true")
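A hypothetical end-to-end invocation in `kv` mode (the only mode `compile`/`inference` accept in this PR). Flag spellings are inferred from the `args.*` attributes used in the script, and all paths are placeholders:

```python
# Hypothetical invocation; nothing here is guaranteed beyond the flags visible in this diff.
import subprocess

subprocess.run(
    [
        "python", "examples/qualcomm/oss_scripts/llama2/llama.py",
        "--model_mode", "kv",                    # batch_prefill / hybrid still raise NotImplementedError
        "--checkpoint", "consolidated.00.pth",   # placeholder checkpoint path
        "--params", "params.json",               # placeholder model-config path
        "--tokenizer_model", "tokenizer.model",  # placeholder tokenizer path
        "--prompt", "Once upon a time",
        "--seq_len", "512",
    ],
    check=True,
)
```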