Update Eagle example to Eagle2+ipex-llm integration (#11717)

* update to e2 example
* update
* update

This commit is contained in:
parent 26390f9213
commit 667f0db466

2 changed files with 261 additions and 122 deletions
@@ -17,8 +17,9 @@ Step 3, you also need to download and install [Intel® oneAPI Base Toolkit](http
 - Intel Data Center GPU Max Series
 - Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series
 
-## Example - EAGLE Speculative Sampling with IPEX-LLM on MT-bench
+## Example - EAGLE-2 Speculative Sampling with IPEX-LLM on MT-bench
 In this example, we run inference for a Llama2 model to showcase the speed of EAGLE with IPEX-LLM on MT-bench data on Intel GPUs.
+We use EAGLE-2, which has better performance than EAGLE-1.
 
 ### 1. Install
 #### 1.1 Installation on Linux
@@ -28,8 +29,10 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install eagle-llm
 git clone https://github.com/SafeAILab/EAGLE.git
 cd EAGLE
+pip install -r requirements.txt
 pip install -e .
 ```
 
 #### 1.2 Installation on Windows
@@ -42,10 +45,10 @@ pip install dpcpp-cpp-rt==2024.0.2 mkl-dpcpp==2024.0.0 onednn==2024.0.0
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
 git clone https://github.com/SafeAILab/EAGLE.git
 cd EAGLE
+pip install -r requirements.txt
-pip install transformers==4.36.2
-pip install gradio==3.50.2
+pip install eagle-llm
 pip install -e .
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
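To confirm that either install path above actually picked up the XPU build, a quick check like the following can help. This is our sketch, not part of the commit; it assumes the `llm` conda environment created above is active.

```python
# Sanity check for the ipex-llm[xpu] install above (our sketch, not part of
# this commit). Run inside the `llm` conda environment.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401 -- registers the "xpu" device

print("torch:", torch.__version__)
print("xpu available:", torch.xpu.is_available())
if torch.xpu.is_available():
    print("device:", torch.xpu.get_device_name(0))
```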
@@ -89,7 +92,7 @@ export ENABLE_SDP_FUSION=1
 ### 4. Running Example
 You can test the speed of EAGLE speculative sampling with ipex-llm on MT-bench using the following command.
 ```bash
-python -m evaluation.gen_ea_answer_llama2chat\
+python -m evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize\
         --ea-model-path [path of EAGLE weight]\
         --base-model-path [path of the original model]\
         --enable-ipex-llm\
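For reference, a fully assembled invocation of the renamed entry point might look like the sketch below. The model paths are placeholders we chose for illustration, not paths shipped with EAGLE.

```python
# Hypothetical launcher for the benchmark command above; both model paths
# are placeholders for locally downloaded weights.
import subprocess

subprocess.run(
    [
        "python", "-m", "evaluation.gen_ea_answer_llama2chat_e2_ipex_optimize",
        "--ea-model-path", "/models/EAGLE-llama2-chat-7B",   # assumed local EAGLE weights
        "--base-model-path", "/models/Llama-2-7b-chat-hf",   # assumed local base model
        "--enable-ipex-llm",
    ],
    check=True,
)
```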
@@ -34,105 +34,223 @@
 import argparse
 import json
 import os
 from accelerate.utils import set_seed
 set_seed(0)
 import time
 
 import torch
 
 import shortuuid
 from fastchat.llm_judge.common import load_questions
 from fastchat.model import get_conversation_template
 from tqdm import tqdm
 
 from eagle.model.ea_model import EaModel
-from eagle.model.utils import *
+from eagle.model.utils import prepare_logits_processor, evaluate_posterior
-from eagle.model.kv_cache import initialize_past_key_values
-from eagle.model.choices import *
+from eagle.modeling_eagle import forward_with_tree_mask
 from ipex_llm import optimize_model
+from ipex_llm.transformers import AutoModelForCausalLM
 
-def ea_forward(input_ids, model, tokenizer, tree_choices, logits_processor=None, max_steps=512):
-    assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
-    # Avoid modifying the input_ids in-place
-    input_ids = input_ids.clone()
-    model.ea_layer.reset_kv()
+class Timer:
+    def __init__(self,name):
+        self.name = name
 
-    if hasattr(model, "tree_choices") and model.tree_choices == tree_choices:
-        tree_buffers = model.tree_buffers
+    def __enter__(self):
+        torch.xpu.synchronize()
+        self.start = time.perf_counter()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        torch.xpu.synchronize()
+        elapsed = time.perf_counter() - self.start
+        print(f'{self.name} took {elapsed} seconds')
+
+
+def initialize_tree(input_ids, model, logits_processor):
+    # outputs, orig, hidden_states = model(
+    #     input_ids, past_key_values=past_key_values, output_orig=True
+    # )
+    hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=input_ids)
+    orig = model.base_model.lm_head(hidden_states)
+
+    if logits_processor is not None:
+        logits = orig[:, -1]
+        logits = logits_processor(None, logits)
+        probabilities = torch.nn.functional.softmax(logits, dim=1)
+        token = torch.multinomial(probabilities, 1)
     else:
-        tree_buffers = generate_tree_buffers(
-            tree_choices, device=model.base_model.model.layers[-1].self_attn.q_proj.weight.device
-        )
-        tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to(
-            model.base_model.lm_head.weight.device)
-    model.tree_buffers = tree_buffers
-    model.tree_choices = tree_choices
+        token = torch.argmax(orig[:, -1])
+        token = token[None, None]
+    input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
-    # Clone the output hidden states
 
-    # Initialize the past key and value states
-    if hasattr(model, "past_key_values"):
-        past_key_values = model.past_key_values
-        past_key_values_data = model.past_key_values_data
-        current_length_data = model.current_length_data
-        # Reset the past key and value states
-        current_length_data.zero_()
-    else:
-        (
-            past_key_values,
-            past_key_values_data,
-            current_length_data,
-        ) = initialize_past_key_values(model.base_model)
-        model.past_key_values = past_key_values
-        model.past_key_values_data = past_key_values_data
-        model.current_length_data = current_length_data
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(hidden_states, input_ids, model.base_model.lm_head,logits_processor)
+    return draft_tokens, retrieve_indices,tree_mask,tree_position_ids, orig, hidden_states, token, past_key_values
 
-    input_len = input_ids.shape[1]
-    reset_tree_mode(model)
-    tree_logits, logits, hidden_state, sample_token = initialize_tree(
-        input_ids, model, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
-    )
-    new_token = 0
 
-    for idx in range(max_steps):
-        candidates, cart_candidates_prob, tree_candidates = generate_candidates(
-            tree_logits,
-            tree_buffers["tree_indices"],
-            tree_buffers["retrieve_indices"],
-            sample_token,
-            logits_processor
-        )
-        logits, hidden_state_new, outputs = tree_decoding(
+def tree_decoding(
         model,
         tree_candidates,
         past_key_values,
-        tree_buffers["tree_position_ids"],
+        tree_position_ids,
         input_ids,
-        tree_buffers["retrieve_indices_head"],
-        )
-        best_candidate, accept_length, sample_p = evaluate_posterior(
-            logits, candidates, logits_processor, cart_candidates_prob, tree_logits[2], tree_buffers["p_indices"],
-            tree_candidates, tree_buffers["b_indices"]
-        )
-        input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
+        retrieve_indices,
+        attention_mask=None,
+        tree_mask=None,
+):
+    position_ids = tree_position_ids + input_ids.shape[1]
+
+    hidden_states, past_key_values = forward_with_tree_mask(model.base_model.model, input_ids=tree_candidates, past_key_values=past_key_values,
+                                                            position_ids=position_ids, tree_mask=tree_mask, attention_mask=attention_mask)
+
+    tree_logits = model.base_model.lm_head(hidden_states)
+
+    logits = tree_logits[0, retrieve_indices]
+
+    return logits, hidden_states, past_key_values
+
+
+def update_inference_inputs(
         input_ids,
         candidates,
        best_candidate,
        accept_length,
-        tree_buffers["retrieve_indices"],
+        retrieve_indices,
        logits_processor,
-        logits,
-        tree_logits,
        new_token,
-        past_key_values_data,
-        current_length_data,
+        past_key_values,
+        # current_length_data,
        model,
        hidden_state_new,
        sample_p
+):
+    prev_input_len = input_ids.shape[1]
+    # Map the best candidate indices to the original indices in the sequence
+    select_indices = (
+        retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
+    )
+    # Append the tokens from the best candidate to the input sequence
+    input_ids = torch.cat(
+        [input_ids, candidates[None, best_candidate, : accept_length + 1].to(input_ids.device)], dim=-1
+    )
+
+    new_kv = ()
+
+    for past_key_values_data in past_key_values:
+        layer_kv = ()
+        for korv in past_key_values_data:
+            tgt = korv[:, :, select_indices, :]
+            dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :]
+            dst.copy_(tgt, non_blocking=True)
+            layer_kv += (korv[:, :, : prev_input_len + tgt.shape[-2], :],)
+        new_kv += (layer_kv,)
+
+    retrieve_hidden_state_new = hidden_state_new[:, retrieve_indices]
+    accept_hidden_state_new = retrieve_hidden_state_new[:, best_candidate, : accept_length + 1]
+
+    prob = sample_p
+    if logits_processor is not None:
+        token = torch.multinomial(prob, 1)
+        token = token[None]
+    else:
+        token = torch.argmax(prob)
+        token = token[None, None]
+
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids = model.ea_layer.topK_genrate(accept_hidden_state_new,
+                                              input_ids=torch.cat((input_ids, token.to(input_ids.device)), dim=1),
+                                              head=model.base_model.lm_head,logits_processor=logits_processor)
+
+    new_token += accept_length + 1
+
+    return input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, None, token, new_kv
+
+
+def eagenerate(
+        model,
+        input_ids,
+        temperature=0.0,
+        top_p=0.0,
+        top_k=0.0,
+        max_new_tokens=512,
+        max_length=2048,
+        log=False,
+        is_llama3=False,
+
+):
+    if is_llama3:
+        stop_token_id = model.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    max_length=max_length-model.ea_layer.total_tokens-10
+
+    if temperature > 1e-5:
+        logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
+    else:
+        logits_processor = None
+    #assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
+    # Avoid modifying the input_ids in-place
+
+    padding=(torch.zeros(1,1,dtype=torch.long)-1).to(input_ids.device)
+    input_ids = input_ids.clone()
+    model.ea_layer.reset_kv()
+
+    input_len = input_ids.shape[1]
+    # with Timer("initialize_tree"):
+    draft_tokens, retrieve_indices,tree_mask,tree_position_ids, logits, hidden_state, sample_token, past_key_values = initialize_tree(
+        input_ids, model, logits_processor
+    )
+
+    new_token = 0
+
+    for idx in range(max_length):
+        # with Timer("all"):
+        draft_tokens=draft_tokens.to(input_ids.device)
+        #with Timer("tree_decoding"):
+        logits, hidden_state_new, past_key_values = tree_decoding(
+            model,
+            draft_tokens,
+            past_key_values,
+            tree_position_ids,
+            input_ids,
+            retrieve_indices,
+            tree_mask=tree_mask,
+        )
+
+        draft_tokens=torch.cat((draft_tokens,padding),dim=1)
+        candidates=draft_tokens[0,retrieve_indices]
+        # with Timer("evaluate_posterior"):
+        best_candidate, accept_length, sample_p = evaluate_posterior(
+            logits, candidates, logits_processor
+        )
+
+        # print("new_token: ", (accept_length+1).item())
+        # with Timer("update_inference_inputs"):
+        input_ids, draft_tokens, retrieve_indices,tree_mask,tree_position_ids, new_token, hidden_state, sample_token, past_key_values = update_inference_inputs(
+            input_ids,
+            candidates,
+            best_candidate,
+            accept_length,
+            retrieve_indices,
+            logits_processor,
+            new_token,
+            past_key_values,
+            # current_length_data,
+            model,
+            hidden_state_new,
+            sample_p
+        )
-        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
+
+        if is_llama3:
+            if stop_token_id in input_ids[0, input_len:].tolist():
+                break
-        if new_token > 1024:
+
+        if model.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
             break
-        if input_ids.shape[1] > 1960:
+        if new_token > max_new_tokens:
             break
+        if input_ids.shape[1] > max_length:
+            break
     if not log:
         return input_ids
     else:
         return input_ids, new_token, idx
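The new `Timer` context manager brackets the timed region with `torch.xpu.synchronize()`: XPU kernels are queued asynchronously, so timing without the sync measures launch overhead rather than execution time. A standalone sketch of the same pattern (ours, with an arbitrary matmul workload):

```python
# Our standalone sketch of the synchronize-before-and-after timing pattern
# used by the Timer class in this diff; the matmul workload is arbitrary.
import time
import torch
import intel_extension_for_pytorch  # noqa: F401 -- enables the XPU backend

def timed(fn, label, sync):
    if sync:
        torch.xpu.synchronize()   # wait for queued work before starting the clock
    start = time.perf_counter()
    fn()
    if sync:
        torch.xpu.synchronize()   # wait for the kernel to finish before stopping it
    print(f"{label}: {time.perf_counter() - start:.4f}s")

if torch.xpu.is_available():
    a = torch.randn(4096, 4096, device="xpu")
    timed(lambda: a @ a, "matmul, unsynced", sync=False)  # mostly launch overhead
    timed(lambda: a @ a, "matmul, synced", sync=True)     # actual kernel time
```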
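The heart of the new `update_inference_inputs` is the KV-cache compaction: `select_indices` gathers the cache rows of the accepted candidate path and copies them into the contiguous slots right after the committed prefix. A toy reproduction on a dummy cache tensor (shapes and values are ours, purely illustrative):

```python
# Toy reproduction of the KV-cache compaction above; all shapes and values
# are illustrative, not taken from the model.
import torch

korv = torch.arange(80, dtype=torch.float32).reshape(1, 2, 10, 4)  # (batch, heads, seq, head_dim)
prev_input_len = 5                                       # tokens already committed
retrieve_indices = torch.tensor([[0, 2, 3], [0, 1, 4]])  # two candidate paths in the draft tree
best_candidate, accept_length = 1, 1                     # accept 2 tokens of path 1

select_indices = retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
tgt = korv[:, :, select_indices, :]                                  # gather accepted rows
dst = korv[:, :, prev_input_len: prev_input_len + tgt.shape[-2], :]  # contiguous destination
dst.copy_(tgt)                                                       # compact in place
kept = korv[:, :, : prev_input_len + tgt.shape[-2], :]               # truncated cache view
print(kept.shape)  # torch.Size([1, 2, 7, 4])
```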
@@ -147,8 +265,8 @@ def run_eval(
     max_new_token,
     num_choices,
     temperature,
-    tree_choices,
     enable_ipex_llm,
+    args,
 ):
     questions = load_questions(question_file, question_begin, question_end)
     shuffled_ids = [q["question_id"] for q in questions]
@@ -168,8 +286,8 @@ def run_eval(
                 max_new_token,
                 num_choices,
                 temperature,
-                tree_choices,
                 enable_ipex_llm,
+                args
             )
         )
 
@@ -184,34 +302,32 @@ def get_model_answers(
     max_new_token,
     num_choices,
     temperature,
-    tree_choices,
     enable_ipex_llm,
+    args
 ):
-    try:
     model = EaModel.from_pretrained(
         base_model_path=base_model_path,
         ea_model_path=ea_model_path,
-        #torch_dtype=torch.float16,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float16,
+        # torch_dtype=torch.float32,
         low_cpu_mem_usage=True,
         # load_in_8bit=True,
         device_map="auto",
-        )
-    except ValueError:
-        print("Using sequential device_map.")
-        model = EaModel.from_pretrained(
-            base_model_path=base_model_path,
-            ea_model_path=ea_model_path,
-            #torch_dtype=torch.float16,
-            torch_dtype=torch.float32,
-            low_cpu_mem_usage=True,
-            # load_in_8bit=True,
-            device_map="sequential"
+        total_token=args.total_token,
+        depth=args.depth,
+        top_k=args.top_k,
+
     )
     if enable_ipex_llm:
-        # single line of change to enable ipex-llm
-        model = optimize_model(model, low_bit='sym_int4', optimize_llm=False)
-    model.to("xpu")
+        # Use optimized int4 base model to replace base model in EaModel
+        base_model = AutoModelForCausalLM.from_pretrained(base_model_path, load_in_4bit=True, use_cache=True,
+                                                          torch_dtype=torch.float16)
+        model.base_model = base_model
+        # Also optimize draft model in EaModel
+        model.ea_layer = optimize_model(model.ea_layer)
+    model = model.to("xpu")
+    model.ea_layer.tree_mask_init = model.ea_layer.tree_mask_init.to("xpu")
+    model.ea_layer.position_ids = model.ea_layer.position_ids.to("xpu")
     tokenizer = model.get_tokenizer()
 
     if temperature > 1e-5:
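With `--enable-ipex-llm`, the commit replaces the FP16 base model inside `EaModel` with an INT4 copy loaded through ipex-llm's `AutoModelForCausalLM`, and runs the small draft network (`ea_layer`) through `optimize_model`. A minimal standalone sketch of those two calls; the model path is a placeholder, we assume an XPU device is present, and the `Linear` merely stands in for a draft module:

```python
# Minimal sketch of the two ipex-llm hooks used above; the path is a
# placeholder and the Linear stands in for a draft module like ea_layer.
import torch
from ipex_llm import optimize_model
from ipex_llm.transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "/models/Llama-2-7b-chat-hf",  # placeholder local checkpoint
    load_in_4bit=True,             # weight-only INT4 quantization
    use_cache=True,
    torch_dtype=torch.float16,
).to("xpu")

draft = torch.nn.Sequential(torch.nn.Linear(4096, 4096))  # stand-in draft module
draft = optimize_model(draft).to("xpu")                   # same call as for model.ea_layer
```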
@@ -246,15 +362,17 @@ def get_model_answers(
             inputs = tokenizer([prompt], return_tensors="pt").to("xpu")
             input_ids = inputs.input_ids
 
             torch.xpu.synchronize()
             start_time = time.time()
-            output_ids, new_token, idx = ea_forward(
-                torch.as_tensor(input_ids),
-                model,
-                tokenizer,
-                tree_choices,
-                logits_processor,
-            )
+
+            output_ids, new_token, idx = eagenerate(
+                model,
+                torch.as_tensor(input_ids),
+                temperature=temperature,
+                max_new_tokens=max_new_token,
+                log=True
+            )
             torch.xpu.synchronize()
             total_time = time.time() - start_time
             output_ids = output_ids[0][len(input_ids[0]):]
             # be consistent with the template's stop_token_ids
@@ -311,14 +429,16 @@ def get_model_answers(
                 input_ids = inputs.input_ids
 
                 try:
                     torch.xpu.synchronize()
                     start_time = time.time()
-                    output_ids, new_token, idx = ea_forward(
-                        torch.as_tensor(input_ids),
+                    output_ids, new_token, idx = eagenerate(
                         model,
-                        tokenizer,
-                        tree_choices,
-                        logits_processor,
+                        torch.as_tensor(input_ids),
+                        temperature=temperature,
+                        max_new_tokens=max_new_token,
+                        log=True,
                     )
                     torch.xpu.synchronize()
                     total_time = time.time() - start_time
                     output_ids = output_ids[0][len(input_ids[0]):]
 
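Since `new_token` counts accepted tokens and `total_time` is measured between the two `torch.xpu.synchronize()` calls, the usual MT-bench summary statistics fall out directly. A sketch with made-up numbers (ours):

```python
# Our sketch: throughput/acceptance stats from the quantities measured above.
new_token, idx, total_time = 128, 20, 2.5   # made-up example values

tokens_per_second = new_token / total_time  # decode throughput
mean_accepted = new_token / (idx + 1)       # avg tokens committed per tree step
print(f"{tokens_per_second:.1f} tok/s, {mean_accepted:.2f} tokens/step")
```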
@@ -417,6 +537,22 @@ if __name__ == "__main__":
         default=1024,
         help="The maximum number of new generated tokens.",
     )
+    parser.add_argument(
+        "--total-token",
+        type=int,
+        default=8,
+        help="The total number of nodes in the draft tree",
+    )
+    parser.add_argument(
+        "--depth",
+        type=int,
+        default=4,
+    )
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=5,
+    )
     parser.add_argument(
         "--num-choices",
         type=int,
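The three new flags are forwarded to `EaModel.from_pretrained` and size EAGLE-2's dynamic draft tree: per our reading of the EAGLE-2 design, each node expands its `top_k` most probable children down to `depth` levels, and the tree is then pruned to the `total_token` highest-confidence draft tokens. A small illustration of the resulting bound (ours, not code from this repo):

```python
# Our illustration of how the new flags bound the draft-tree size; this is
# based on the EAGLE-2 paper's description, not on this repository's code.
def max_draft_nodes(top_k: int, depth: int, total_token: int) -> int:
    unpruned = sum(top_k ** level for level in range(1, depth + 1))  # full k-ary tree
    return min(unpruned, total_token)                                # pruning cap

print(max_draft_nodes(top_k=5, depth=4, total_token=8))  # defaults above -> 8
```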
@@ -426,7 +562,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--temperature",
         type=float,
-        default=1.0,
+        default=0.0,
     )
 
     parser.add_argument(
@@ -443,7 +579,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     args.model_id = args.model_id + "-temperature-" + str(args.temperature)
-    args.tree_choices = eval(args.tree_choices)
+    # args.tree_choices = eval(args.tree_choices)
 
     question_file = f"data/{args.bench_name}/question.jsonl"
     if args.answer_file:
@@ -464,8 +600,8 @@ if __name__ == "__main__":
         args.max_new_token,
         args.num_choices,
         args.temperature,
-        args.tree_choices,
         args.enable_ipex_llm,
+        args
     )
 
     reorg_answer_file(answer_file)