diff --git a/python/llm/example/GPU/Speculative-Decoding/EAGLE/README.md b/python/llm/example/GPU/Speculative-Decoding/EAGLE/README.md
index 611c86c6..653d6f19 100644
--- a/python/llm/example/GPU/Speculative-Decoding/EAGLE/README.md
+++ b/python/llm/example/GPU/Speculative-Decoding/EAGLE/README.md
@@ -17,8 +17,9 @@ Step 3, you also need to download and install [Intel® oneAPI Base Toolkit](http
 - Intel Data Center GPU Max Series
 - Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
 
-## Example - EAGLE Speculative Sampling with IPEX-LLM on MT-bench
+## Example - EAGLE-2 Speculative Sampling with IPEX-LLM on MT-bench
 In this example, we run inference for a Llama2 model to showcase the speed of EAGLE with IPEX-LLM on MT-bench data on Intel GPUs.
+We use EAGLE-2, which has better performance than EAGLE-1.
 
 ### 1. Install
 #### 1.1 Installation on Linux
@@ -28,8 +29,10 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install eagle-llm
+git clone https://github.com/SafeAILab/EAGLE.git
+cd EAGLE
 pip install -r requirements.txt
+pip install -e .
 ```
 
 #### 1.2 Installation on Windows
@@ -42,10 +45,10 @@ pip install dpcpp-cpp-rt==2024.0.2 mkl-dpcpp==2024.0.0 onednn==2024.0.0
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+git clone https://github.com/SafeAILab/EAGLE.git
+cd EAGLE
 pip install -r requirements.txt
-pip install transformers==4.36.2
-pip install gradio==3.50.2
-pip install eagle-llm
+pip install -e .
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
@@ -89,7 +92,7 @@ export ENABLE_SDP_FUSION=1
 ### 4. Running Example
 You can test the speed of EAGLE speculative sampling with ipex-llm on MT-bench using the following command.
```bash -python -m evaluation.gen_ea_answer_llama2chat\ +python -m evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize\ --ea-model-path [path of EAGLE weight]\ --base-model-path [path of the original model]\ --enable-ipex-llm\ diff --git a/python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat.py b/python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat_e2_ipex_optimize.py similarity index 58% rename from python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat.py rename to python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat_e2_ipex_optimize.py index ffa9adfb..091bef31 100644 --- a/python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat.py +++ b/python/llm/example/GPU/Speculative-Decoding/EAGLE/evaluation/gen_ea_answer_llama2chat_e2_ipex_optimize.py @@ -34,106 +34,224 @@ import argparse import json import os -from accelerate.utils import set_seed -set_seed(0) import time +import torch + import shortuuid from fastchat.llm_judge.common import load_questions from fastchat.model import get_conversation_template from tqdm import tqdm from eagle.model.ea_model import EaModel -from eagle.model.utils import * +from eagle.model.utils import prepare_logits_processor, evaluate_posterior from eagle.model.kv_cache import initialize_past_key_values from eagle.model.choices import * +from eagle.modeling_eagle import forward_with_tree_mask from ipex_llm import optimize_model +from ipex_llm.transformers import AutoModelForCausalLM -def ea_forward(input_ids, model, tokenizer, tree_choices, logits_processor=None, max_steps=512): - assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" - # Avoid modifying the input_ids in-place - input_ids = input_ids.clone() - model.ea_layer.reset_kv() +class Timer: + def __init__(self,name): + self.name = name - if hasattr(model, "tree_choices") and model.tree_choices == tree_choices: - tree_buffers = model.tree_buffers + def __enter__(self): + torch.xpu.synchronize() + self.start = time.perf_counter() + + + def __exit__(self, exc_type, exc_value, traceback): + torch.xpu.synchronize() + elapsed = time.perf_counter() - self.start + print(f'{self.name} took {elapsed} seconds') + + +def initialize_tree(input_ids, model, logits_processor): + # outputs, orig, hidden_states = model( + # input_ids, past_key_values=past_key_values, output_orig=True + # ) + hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=input_ids) + orig = model.base_model.lm_head(hidden_states) + + if logits_processor is not None: + logits = orig[:, -1] + logits = logits_processor(None, logits) + probabilities = torch.nn.functional.softmax(logits, dim=1) + token = torch.multinomial(probabilities, 1) else: - tree_buffers = generate_tree_buffers( - tree_choices, device=model.base_model.model.layers[-1].self_attn.q_proj.weight.device - ) - tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to( - model.base_model.lm_head.weight.device) - model.tree_buffers = tree_buffers - model.tree_choices = tree_choices + token = torch.argmax(orig[:, -1]) + token = token[None, None] + input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1) + # Clone the output hidden states - # Initialize the past key and value states - if hasattr(model, "past_key_values"): - past_key_values = model.past_key_values - past_key_values_data = model.past_key_values_data - current_length_data = 
model.current_length_data - # Reset the past key and value states - current_length_data.zero_() - else: - ( - past_key_values, - past_key_values_data, - current_length_data, - ) = initialize_past_key_values(model.base_model) - model.past_key_values = past_key_values - model.past_key_values_data = past_key_values_data - model.current_length_data = current_length_data + draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor) + return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token, past_key_values - input_len = input_ids.shape[1] - reset_tree_mode(model) - tree_logits, logits, hidden_state, sample_token = initialize_tree( - input_ids, model, tree_buffers["tree_attn_mask"], past_key_values, logits_processor +def tree_decoding( + model, + tree_candidates, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, + attention_mask=None, + tree_mask=None, +): + position_ids = tree_position_ids + input_ids.shape[1] + + hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates, past_key_values=past_key_values, + position_ids=position_ids, tree_mask=tree_mask, attention_mask=attention_mask) + + tree_logits = model.base_model.lm_head(hidden_states) + + logits = tree_logits[0, retrieve_indices] + + return logits, hidden_states, past_key_values + + +def update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values, + # current_length_data, + model, + hidden_state_new, + sample_p +): + prev_input_len = input_ids.shape[1] + # Map the best candidate indices to the original indices in the sequence + select_indices = ( + retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len + ) + # Append the tokens from the best candidate to the input sequence + input_ids = torch.cat( + [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1 ) - new_token = 0 - for idx in range(max_steps): - candidates, cart_candidates_prob, tree_candidates = generate_candidates( - tree_logits, - tree_buffers["tree_indices"], - tree_buffers["retrieve_indices"], - sample_token, - logits_processor - ) - logits, hidden_state_new, outputs = tree_decoding( + new_kv = () + + for past_key_values_data in past_key_values: + layer_kv = () + for korv in past_key_values_data: + tgt = korv[:, :, select_indices, :] + dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :] + dst.copy_(tgt, non_blocking=True) + layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],) + new_kv += (layer_kv,) + + retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices] + accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1] + + prob = sample_p + if logits_processor is not None: + token = torch.multinomial(prob, 1) + token = token[None] + else: + token = torch.argmax(prob) + token = token[None, None] + + draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new, + input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1), + head=model.base_model.lm_head,logits_processor=logits_processor) + + + new_token += accept_length + 1 + + return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token, new_kv + + +def eagenerate( model, - tree_candidates, - 
past_key_values, - tree_buffers["tree_position_ids"], input_ids, - tree_buffers["retrieve_indices_head"], + temperature=0.0, + top_p=0.0, + top_k=0.0, + max_new_tokens=512, + max_length=2048, + log=False, + is_llama3=False, + + ): + if is_llama3: + stop_token_id = model.tokenizer.convert_tokens_to_ids("<|eot_id|>") + max_length=max_length-model.ea_layer.total_tokens-10 + + if temperature > 1e-5: + logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k) + else: + logits_processor = None + #assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # Avoid modifying the input_ids in-place + + padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device) + input_ids = input_ids.clone() + model.ea_layer.reset_kv() + + input_len = input_ids.shape[1] + # with Timer("initialize_tree"): + draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token, past_key_values = initialize_tree( + input_ids, model, logits_processor ) - best_candidate, accept_length, sample_p = evaluate_posterior( - logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"], - tree_candidates, tree_buffers["b_indices"] - ) - input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs( - input_ids, - candidates, - best_candidate, - accept_length, - tree_buffers["retrieve_indices"], - logits_processor, - logits, - tree_logits, - new_token, - past_key_values_data, - current_length_data, - model, - hidden_state, - hidden_state_new, - sample_p - ) - if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): - break - if new_token > 1024: - break - if input_ids.shape[1] > 1960: - break - return input_ids, new_token, idx + + new_token = 0 + + for idx in range(max_length): + # with Timer("all"): + draft_tokens=draft_tokens.to(input_ids.device) + #with Timer("tree_decoding"): + logits, hidden_state_new, past_key_values = tree_decoding( + model, + draft_tokens, + past_key_values, + tree_position_ids, + input_ids, + retrieve_indices, + tree_mask=tree_mask, + ) + + draft_tokens=torch.cat((draft_tokens,padding),dim=1) + candidates=draft_tokens[0,retrieve_indices] + # with Timer("evaluate_posterior"): + best_candidate, accept_length, sample_p = evaluate_posterior( + logits, candidates, logits_processor + ) + + # print("new_token: ", (accept_length+1).item()) + # with Timer("update_inference_inputs"): + input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token, past_key_values = update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + logits_processor, + new_token, + past_key_values, + # current_length_data, + model, + hidden_state_new, + sample_p + ) + + if is_llama3: + if stop_token_id in input_ids[0, input_len:].tolist(): + break + + if model.tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): + break + if new_token > max_new_tokens: + break + if input_ids.shape[1] > max_length: + break + if not log: + return input_ids + else: + return input_ids, new_token, idx def run_eval( @@ -147,8 +265,8 @@ def run_eval( max_new_token, num_choices, temperature, - tree_choices, enable_ipex_llm, + args, ): questions = load_questions(question_file, question_begin, question_end) shuffled_ids = [q["question_id"] for q in questions] @@ -168,8 +286,8 @@ def run_eval( max_new_token, num_choices, temperature, - tree_choices, enable_ipex_llm, + args ) ) @@ -184,34 +302,32 @@ def 
get_model_answers( max_new_token, num_choices, temperature, - tree_choices, enable_ipex_llm, + args ): - try: - model = EaModel.from_pretrained( + + model = EaModel.from_pretrained( base_model_path=base_model_path, ea_model_path=ea_model_path, - #torch_dtype=torch.float16, - torch_dtype=torch.float32, + torch_dtype=torch.float16, + # torch_dtype=torch.float32, low_cpu_mem_usage=True, # load_in_8bit=True, - device_map="auto" - ) - except ValueError: - print("Using sequential device_map.") - model = EaModel.from_pretrained( - base_model_path=base_model_path, - ea_model_path=ea_model_path, - #torch_dtype=torch.float16, - torch_dtype=torch.float32, - low_cpu_mem_usage=True, - # load_in_8bit=True, - device_map="sequential" + total_token=args.total_token, + depth=args.depth, + top_k=args.top_k, + ) if enable_ipex_llm: - # single line of change to enable ipex-llm - model = optimize_model(model, low_bit='sym_int4', optimize_llm=False) - model.to("xpu") + # Use optimized int4 base model to replace base model in EaModel + base_model = AutoModelForCausalLM.from_pretrained(base_model_path, load_in_4bit=True, use_cache=True, + torch_dtype=torch.float16) + model.base_model = base_model + # Also optimize draft model in EaModel + model.ea_layer = optimize_model(model.ea_layer) + model = model.to("xpu") + model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu") + model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu") tokenizer = model.get_tokenizer() if temperature > 1e-5: @@ -246,15 +362,17 @@ def get_model_answers( inputs = tokenizer([prompt], return_tensors="pt").to("xpu") input_ids = inputs.input_ids + torch.xpu.synchronize() start_time = time.time() - output_ids, new_token, idx = ea_forward( - torch.as_tensor(input_ids), - model, - tokenizer, - tree_choices, - logits_processor, - ) + output_ids, new_token, idx = eagenerate( + model, + torch.as_tensor(input_ids), + temperature=temperature, + max_new_tokens=max_new_token, + log=True + ) + torch.xpu.synchronize() total_time = time.time() - start_time output_ids = output_ids[0][len(input_ids[0]):] # be consistent with the template's stop_token_ids @@ -311,14 +429,16 @@ def get_model_answers( input_ids = inputs.input_ids try: + torch.xpu.synchronize() start_time = time.time() - output_ids, new_token, idx = ea_forward( - torch.as_tensor(input_ids), + output_ids, new_token, idx = eagenerate( model, - tokenizer, - tree_choices, - logits_processor, + torch.as_tensor(input_ids), + temperature=temperature, + max_new_tokens=max_new_token, + log=True, ) + torch.xpu.synchronize() total_time = time.time() - start_time output_ids = output_ids[0][len(input_ids[0]):] @@ -417,6 +537,22 @@ if __name__ == "__main__": default=1024, help="The maximum number of new generated tokens.", ) + parser.add_argument( + "--total-token", + type=int, + default=8, + help="The total number of nodes in the draft tree", + ) + parser.add_argument( + "--depth", + type=int, + default=4, + ) + parser.add_argument( + "--top-k", + type=int, + default=5, + ) parser.add_argument( "--num-choices", type=int, @@ -426,7 +562,7 @@ if __name__ == "__main__": parser.add_argument( "--temperature", type=float, - default=1.0, + default=0.0, ) parser.add_argument( @@ -443,7 +579,7 @@ if __name__ == "__main__": args = parser.parse_args() args.model_id = args.model_id + "-temperature-" + str(args.temperature) - args.tree_choices = eval(args.tree_choices) + # args.tree_choices = eval(args.tree_choices) question_file = f"data/{args.bench_name}/question.jsonl" if args.answer_file: @@ 
-464,8 +600,8 @@ if __name__ == "__main__":
         args.max_new_token,
         args.num_choices,
         args.temperature,
-        args.tree_choices,
         args.enable_ipex_llm,
+        args
     )
 
     reorg_answer_file(answer_file)
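
For readers adapting this patch outside the MT-bench harness, the snippet below is a minimal sketch of the loading and optimization flow introduced in `get_model_answers` above. The checkpoint paths and the prompt are placeholders, and the `eagenerate` import assumes the current working directory is the example folder (the same assumption the README's `python -m evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize` command makes); the individual calls mirror ones that appear in the diff.

```python
# Minimal sketch distilled from this patch; not part of the PR itself.
# Assumptions: EAGLE and ipex-llm[xpu] are installed as in the README,
# the two paths point to real checkpoints, and the working directory is
# the example folder so the `evaluation` package is importable.
import torch
from eagle.model.ea_model import EaModel
from ipex_llm import optimize_model
from ipex_llm.transformers import AutoModelForCausalLM

from evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize import eagenerate

base_model_path = "path/to/Llama-2-7b-chat-hf"   # placeholder
ea_model_path = "path/to/EAGLE-llama2-chat-7B"   # placeholder

# Load the EAGLE-2 model (base model + draft ea_layer) in fp16, with the
# same draft-tree settings the script exposes as CLI flags.
model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=ea_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    total_token=8,  # --total-token: nodes in the draft tree
    depth=4,        # --depth
    top_k=5,        # --top-k
)

# Swap in an ipex-llm INT4-optimized base model and optimize the draft
# layer, as get_model_answers does when --enable-ipex-llm is set.
model.base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, load_in_4bit=True, use_cache=True, torch_dtype=torch.float16
)
model.ea_layer = optimize_model(model.ea_layer)
model = model.to("xpu")
model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu")
model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu")

tokenizer = model.get_tokenizer()
prompt = "[INST] Tell me about Intel GPUs. [/INST]"  # placeholder prompt
input_ids = tokenizer([prompt], return_tensors="pt").to("xpu").input_ids

# Greedy EAGLE-2 generation; log=True also returns token and step counts.
output_ids, new_token, idx = eagenerate(
    model, input_ids, temperature=0.0, max_new_tokens=128, log=True
)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))
```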
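
One more note on the values returned by `eagenerate(..., log=True)`: `new_token` is the number of committed tokens and `idx` is the 0-based count of draft/verify iterations, so the mean accepted length per base-model forward pass falls out directly. The helper below is illustrative only (it is not in the patch); it just names that arithmetic.

```python
# Illustrative helper, not part of the patch: average number of tokens
# accepted per draft/verify step, from eagenerate(..., log=True) outputs.
def mean_accepted_length(new_token, idx):
    """`idx` is 0-based, so idx + 1 verification steps were executed.

    A value of 1.0 would mean no speculative gain; EAGLE-2 typically
    commits several tokens per step. `new_token` may be a Python int or a
    0-d tensor, since the loop accumulates it from `accept_length + 1`.
    """
    return float(new_token) / (idx + 1)

# e.g. call mean_accepted_length(new_token, idx) right after eagenerate().
```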